"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4c54519e1a640f393ff790a72be38284d4253b45"
Unverified Commit a9fdb3de authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Return Flax scheduler state (#601)

* Optionally return state in from_config.

Useful for Flax schedulers.

* has_state is now a property, make check more strict.

I don't check the class is `SchedulerMixin` to prevent circular
dependencies. It should be enough that the class name starts with "Flax"
the object declares it "has_state" and the "create_state" exists too.

* Use state in pipeline from_pretrained.

* Make style
parent e72f1a8a
...@@ -160,12 +160,19 @@ class ConfigMixin: ...@@ -160,12 +160,19 @@ class ConfigMixin:
if "dtype" in unused_kwargs: if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype") init_dict["dtype"] = unused_kwargs.pop("dtype")
# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict) model = cls(**init_dict)
return_tuple = (model,)
# Flax schedulers have a state, so return it.
if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
state = model.create_state()
return_tuple += (state,)
if return_unused_kwargs: if return_unused_kwargs:
return model, unused_kwargs return return_tuple + (unused_kwargs,)
else: else:
return model return return_tuple if len(return_tuple) > 1 else model
@classmethod @classmethod
def get_config_dict( def get_config_dict(
......
...@@ -437,8 +437,8 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -437,8 +437,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
params[name] = loaded_params params[name] = loaded_params
elif issubclass(class_obj, SchedulerMixin): elif issubclass(class_obj, SchedulerMixin):
loaded_sub_model = load_method(loadable_folder) loaded_sub_model, scheduler_state = load_method(loadable_folder)
params[name] = loaded_sub_model.create_state() params[name] = scheduler_state
else: else:
loaded_sub_model = load_method(loadable_folder) loaded_sub_model = load_method(loadable_folder)
......
...@@ -105,6 +105,10 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -105,6 +105,10 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
stable diffusion. stable diffusion.
""" """
@property
def has_state(self):
return True
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
......
...@@ -113,6 +113,10 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -113,6 +113,10 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
stable diffusion. stable diffusion.
""" """
@property
def has_state(self):
return True
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment