Unverified Commit 99029ab6 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Better error raised when cloned without lfs (#13401)

* Better error raised when cloned without lfs

* add from e
parent 18447c20
...@@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple, Union ...@@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple, Union
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict from flax.traverse_util import flatten_dict, unflatten_dict
...@@ -348,8 +349,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -348,8 +349,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
with open(resolved_archive_file, "rb") as state_f: with open(resolved_archive_file, "rb") as state_f:
try: try:
state = from_bytes(cls, state_f.read()) state = from_bytes(cls, state_f.read())
except UnpicklingError: except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ") try:
with open(resolved_archive_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.arrays # make sure all arrays are stored as jnp.arrays
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261 # https://github.com/google/flax/issues/1261
......
...@@ -1334,11 +1334,22 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1334,11 +1334,22 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
ignore_mismatched_sizes=ignore_mismatched_sizes, ignore_mismatched_sizes=ignore_mismatched_sizes,
_prefix=load_weight_prefix, _prefix=load_weight_prefix,
) )
except OSError: except OSError as e:
raise OSError( try:
"Unable to load weights from h5 file. " with open(resolved_archive_file) as f:
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. " if f.read().startswith("version"):
) raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise OSError(
"Unable to load weights from h5 file. "
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
)
model(model.dummy_inputs) # Make sure restore ops are run model(model.dummy_inputs) # Make sure restore ops are run
......
...@@ -1285,12 +1285,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1285,12 +1285,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if state_dict is None: if state_dict is None:
try: try:
state_dict = torch.load(resolved_archive_file, map_location="cpu") state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception: except Exception as e:
raise OSError( try:
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " with open(resolved_archive_file) as f:
f"at '{resolved_archive_file}'" if f.read().startswith("version"):
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " raise OSError(
) "You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
f"at '{resolved_archive_file}'"
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)
# set dtype to instantiate the model under: # set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype # 1. If torch_dtype is not None, we use that dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment