Unverified Commit 486134e5 authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Fix FlaxPretTrainedModel pt weights check (#19133)



* Fix FlaxPretTrainedModel pt weights check

* Update src/transformers/modeling_flax_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix raise comment
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent e7fdfc72
...@@ -665,7 +665,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -665,7 +665,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME) archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
is_sharded = True is_sharded = True
# At this stage we don't have a weight file so we will raise an error. # At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment