Unverified Commit f8100600 authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Fix flax from_pretrained pytorch weight check (#603)

parent fb2fbab1
......@@ -307,7 +307,7 @@ class FlaxModelMixin:
# Load model
if os.path.isdir(pretrained_model_name_or_path):
if from_pt:
if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
)
......@@ -315,8 +315,8 @@ class FlaxModelMixin:
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
# Check if pytorch weights exist instead
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
" using `from_pt=True`."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment