Unverified Commit 89300131 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Fix Flax `from_pt` (#1436)

Fix Flax `from_pt`.

It worked for models but not for pipelines.
Accidentally broken in #1107.
parent 6c56f050
...@@ -332,7 +332,7 @@ class FlaxModelMixin: ...@@ -332,7 +332,7 @@ class FlaxModelMixin:
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError( raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model" f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
" using `from_pt=True`." " using `from_pt=True`."
) )
else: else:
raise EnvironmentError( raise EnvironmentError(
......
...@@ -317,8 +317,8 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -317,8 +317,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
# make sure we don't download PyTorch weights # make sure we don't download PyTorch weights, unless when using from_pt
ignore_patterns = "*.bin" ignore_patterns = "*.bin" if not from_pt else []
if cls != FlaxDiffusionPipeline: if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__ requested_pipeline_class = cls.__name__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment