".github/git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "4e79c0db2143e565bd176a832c5d3332f7ebd4f5"
Unverified Commit f8100600 authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Fix flax from_pretrained pytorch weight check (#603)

parent fb2fbab1
...@@ -307,7 +307,7 @@ class FlaxModelMixin: ...@@ -307,7 +307,7 @@ class FlaxModelMixin:
# Load model # Load model
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
if from_pt: if from_pt:
if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
) )
...@@ -315,8 +315,8 @@ class FlaxModelMixin: ...@@ -315,8 +315,8 @@ class FlaxModelMixin:
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint # Load from a Flax checkpoint
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error. # Check if pytorch weights exist instead
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError( raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model" f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
" using `from_pt=True`." " using `from_pt=True`."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment