Unverified Commit 4b8880a3 authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Make flax from_pretrained work with local subfolder (#608)

parent dd350c8a
...@@ -310,26 +310,31 @@ class FlaxModelMixin: ...@@ -310,26 +310,31 @@ class FlaxModelMixin:
) )
# Load model # Load model
if os.path.isdir(pretrained_model_name_or_path): pretrained_path_with_subfolder = (
pretrained_model_name_or_path
if subfolder is None
else os.path.join(pretrained_model_name_or_path, subfolder)
)
if os.path.isdir(pretrained_path_with_subfolder):
if from_pt: if from_pt:
if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
) )
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint # Load from a Flax checkpoint
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
# Check if pytorch weights exist instead # Check if pytorch weights exist instead
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError( raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model" f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
" using `from_pt=True`." " using `from_pt=True`."
) )
else: else:
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}." f"{pretrained_path_with_subfolder}."
) )
else: else:
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment