Unverified Commit 90cddfa8 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Add variant to transformers (#21332)

* Bump onnx in /examples/research_projects/decision_transformer

Bumps [onnx](https://github.com/onnx/onnx) from 1.11.0 to 1.13.0.
- [Release notes](https://github.com/onnx/onnx/releases)
- [Changelog](https://github.com/onnx/onnx/blob/main/docs/Changelog.md)
- [Commits](https://github.com/onnx/onnx/compare/v1.11.0...v1.13.0

)

---
updated-dependencies:
- dependency-name: onnx
  dependency-type: direct:production
...
Signed-off-by: default avatardependabot[bot] <support@github.com>

* adapt

* finish

* Update examples/research_projects/decision_transformer/requirements.txt

* up

* add tests

* Apply suggestions from code review
Co-authored-by: default avatarLucain <lucainp@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* fix test

---------
Signed-off-by: default avatardependabot[bot] <support@github.com>
Co-authored-by: default avatardependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: default avatarLucain <lucainp@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent bc44e947
...@@ -667,6 +667,15 @@ def _load_state_dict_into_meta_model( ...@@ -667,6 +667,15 @@ def _load_state_dict_into_meta_model(
return error_msgs, offload_index, state_dict_index return error_msgs, offload_index, state_dict_index
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)
return weights_name
class ModuleUtilsMixin: class ModuleUtilsMixin:
""" """
A few utilities for `torch.nn.Modules`, to be used as a mixin. A few utilities for `torch.nn.Modules`, to be used as a mixin.
...@@ -1567,6 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1567,6 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
push_to_hub: bool = False, push_to_hub: bool = False,
max_shard_size: Union[int, str] = "10GB", max_shard_size: Union[int, str] = "10GB",
safe_serialization: bool = False, safe_serialization: bool = False,
variant: Optional[str] = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -1604,6 +1614,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1604,6 +1614,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
safe_serialization (`bool`, *optional*, defaults to `False`): safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
kwargs: kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
...@@ -1675,6 +1687,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1675,6 +1687,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Shard the model if it is too big. # Shard the model if it is too big.
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
# Clean the folder from a previous save # Clean the folder from a previous save
...@@ -1701,10 +1715,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1701,10 +1715,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
save_function(shard, os.path.join(save_directory, shard_file)) save_function(shard, os.path.join(save_directory, shard_file))
if index is None: if index is None:
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}") path_to_weights = os.path.join(save_directory, _add_variant(WEIGHTS_NAME, variant))
logger.info(f"Model weights saved in {path_to_weights}")
else: else:
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, save_index_file) save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well # Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f: with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n" content = json.dumps(index, indent=2, sort_keys=True) + "\n"
...@@ -1931,6 +1946,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1931,6 +1946,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here. specify the folder name here.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_tf` or `from_flax`.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
...@@ -2017,6 +2035,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2017,6 +2035,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None) load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
subfolder = kwargs.pop("subfolder", "") subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None) commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
if trust_remote_code is True: if trust_remote_code is True:
logger.warning( logger.warning(
...@@ -2132,42 +2151,57 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2132,42 +2151,57 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Load from a Flax checkpoint in priority if from_flax # Load from a Flax checkpoint in priority if from_flax
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif is_safetensors_available() and os.path.isfile( elif is_safetensors_available() and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME) os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
): ):
# Load from a safetensors checkpoint # Load from a safetensors checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME) archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
)
elif is_safetensors_available() and os.path.isfile( elif is_safetensors_available() and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME) os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
)
): ):
# Load from a sharded safetensors checkpoint # Load from a sharded safetensors checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME) archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True is_sharded = True
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
):
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) archive_file = os.path.join(
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)): pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
)
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
):
# Load from a sharded PyTorch checkpoint # Load from a sharded PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True is_sharded = True
# At this stage we don't have a weight file so we will raise an error. # At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile( elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)): ) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but " f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those " f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
"weights." " `from_tf=True` to load this model from those weights."
) )
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)): elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but " f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
"there is a file for Flax weights. Use `from_flax=True` to load this model from those " f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
"weights." " to load this model from those weights."
) )
else: else:
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or " f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME},"
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}." f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
f" {pretrained_model_name_or_path}."
) )
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
archive_file = pretrained_model_name_or_path archive_file = pretrained_model_name_or_path
...@@ -2190,9 +2224,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2190,9 +2224,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif from_flax: elif from_flax:
filename = FLAX_WEIGHTS_NAME filename = FLAX_WEIGHTS_NAME
elif is_safetensors_available(): elif is_safetensors_available():
filename = SAFE_WEIGHTS_NAME filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
else: else:
filename = WEIGHTS_NAME filename = _add_variant(WEIGHTS_NAME, variant)
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
...@@ -2213,23 +2247,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2213,23 +2247,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not. # result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME: if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file( resolved_archive_file = cached_file(
pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs pretrained_model_name_or_path,
_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
) )
if resolved_archive_file is not None: if resolved_archive_file is not None:
is_sharded = True is_sharded = True
else: else:
# This repo has no safetensors file of any kind, we switch to PyTorch. # This repo has no safetensors file of any kind, we switch to PyTorch.
filename = WEIGHTS_NAME filename = _add_variant(WEIGHTS_NAME, variant)
resolved_archive_file = cached_file( resolved_archive_file = cached_file(
pretrained_model_name_or_path, WEIGHTS_NAME, **cached_file_kwargs pretrained_model_name_or_path, filename, **cached_file_kwargs
) )
if resolved_archive_file is None and filename == WEIGHTS_NAME: if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file( resolved_archive_file = cached_file(
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs pretrained_model_name_or_path,
_add_variant(WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
) )
if resolved_archive_file is not None: if resolved_archive_file is not None:
is_sharded = True is_sharded = True
...@@ -2244,19 +2282,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2244,19 +2282,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named" f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {WEIGHTS_NAME} but there is a file for TensorFlow weights. Use `from_tf=True` to" f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights."
" load this model from those weights." " Use `from_tf=True` to load this model from those weights."
) )
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs): elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named" f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {WEIGHTS_NAME} but there is a file for Flax weights. Use `from_flax=True` to load" f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use"
" this model from those weights." " `from_flax=True` to load this model from those weights."
)
elif variant is not None and has_file(
pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
f" {variant}. Use `variant=None` to load this model from those weights."
) )
else: else:
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}," f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
f" {FLAX_WEIGHTS_NAME}."
) )
except EnvironmentError: except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
...@@ -2268,8 +2315,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2268,8 +2315,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the" " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or" f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
f" {FLAX_WEIGHTS_NAME}." f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
) )
if is_local: if is_local:
......
...@@ -2958,6 +2958,138 @@ class ModelUtilsTest(TestCasePlus): ...@@ -2958,6 +2958,138 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), ref_model.parameters()): for p1, p2 in zip(model.parameters(), ref_model.parameters()):
self.assertTrue(torch.allclose(p1, p2)) self.assertTrue(torch.allclose(p1, p2))
def test_checkpoint_variant_local(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, variant="v2")
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
weights_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
with self.assertRaises(EnvironmentError):
_ = BertModel.from_pretrained(tmp_dir)
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_checkpoint_variant_local_sharded(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB")
weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
weights_index_file = os.path.join(tmp_dir, weights_index_name)
self.assertTrue(os.path.isfile(weights_index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
for i in range(1, 6):
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["bin"])
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))
with self.assertRaises(EnvironmentError):
_ = BertModel.from_pretrained(tmp_dir)
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
@require_safetensors
def test_checkpoint_variant_local_safe(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, variant="v2", safe_serialization=True)
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["safetensors"])
weights_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
with self.assertRaises(EnvironmentError):
_ = BertModel.from_pretrained(tmp_dir)
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
@require_safetensors
def test_checkpoint_variant_local_sharded_safe(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=True)
weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
weights_index_file = os.path.join(tmp_dir, weights_index_name)
self.assertTrue(os.path.isfile(weights_index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
for i in range(1, 6):
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["safetensors"])
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))
with self.assertRaises(EnvironmentError):
_ = BertModel.from_pretrained(tmp_dir)
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_checkpoint_variant_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(EnvironmentError):
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir)
model = BertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
)
self.assertIsNotNone(model)
def test_checkpoint_variant_hub_sharded(self):
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(EnvironmentError):
_ = BertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-variant-sharded", cache_dir=tmp_dir
)
model = BertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-variant-sharded", cache_dir=tmp_dir, variant="v2"
)
self.assertIsNotNone(model)
@require_safetensors
def test_checkpoint_variant_hub_safe(self):
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(EnvironmentError):
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-variant-safe", cache_dir=tmp_dir)
model = BertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-variant-safe", cache_dir=tmp_dir, variant="v2"
)
self.assertIsNotNone(model)
@require_safetensors
def test_checkpoint_variant_hub_sharded_safe(self):
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(EnvironmentError):
_ = BertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-variant-sharded-safe", cache_dir=tmp_dir
)
model = BertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-variant-sharded-safe", cache_dir=tmp_dir, variant="v2"
)
self.assertIsNotNone(model)
@require_accelerate @require_accelerate
def test_from_pretrained_low_cpu_mem_usage_functional(self): def test_from_pretrained_low_cpu_mem_usage_functional(self):
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and # test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment