Unverified Commit 486134e5 authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Fix FlaxPretTrainedModel pt weights check (#19133)



* Fix FlaxPretTrainedModel pt weights check

* Update src/transformers/modeling_flax_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix raise comment
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent e7fdfc72
......@@ -665,7 +665,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment