Unverified Commit 89300131 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Fix Flax `from_pt` (#1436)

Fix Flax `from_pt`.

It worked for models but not for pipelines.
Accidentally broken in #1107.
parent 6c56f050
...@@ -317,8 +317,8 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -317,8 +317,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
# make sure we don't download PyTorch weights # make sure we don't download PyTorch weights, unless when using from_pt
ignore_patterns = "*.bin" ignore_patterns = "*.bin" if not from_pt else []
if cls != FlaxDiffusionPipeline: if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__ requested_pipeline_class = cls.__name__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment