Unverified Commit 89300131 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Fix Flax `from_pt` (#1436)

Fix Flax `from_pt`.

It worked for models but not for pipelines.
Accidentally broken in #1107.
parent 6c56f050
......@@ -332,7 +332,7 @@ class FlaxModelMixin:
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
" using `from_pt=True`."
" using `from_pt=True`."
)
else:
raise EnvironmentError(
......
......@@ -317,8 +317,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
# make sure we don't download PyTorch weights
ignore_patterns = "*.bin"
# make sure we don't download PyTorch weights, unless when using from_pt
ignore_patterns = "*.bin" if not from_pt else []
if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment