Unverified Commit 6eaebe82 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Utils] Adds `store()` and `restore()` methods to EMAModel (#2302)



* add store and restore() methods to EMAModel.

* Update src/diffusers/training_utils.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* make style with doc builder

* remove explicit listing.

* Apply suggestions from code review
Co-authored-by: default avatarWill Berman <wlbberman@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* chore: better variable naming.

* better treatment of temp_stored_params
Co-authored-by: default avatarpatil-suraj <surajp815@gmail.com>

* make style

* remove temporary params from earth 🌎



* make fix-copies.

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarWill Berman <wlbberman@gmail.com>
Co-authored-by: default avatarpatil-suraj <surajp815@gmail.com>
parent b214bb25
......@@ -115,7 +115,7 @@ class EMAModel:
deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
self.to(device=kwargs["device"])
self.collected_params = None
self.temp_stored_params = None
self.decay = decay
self.min_decay = min_decay
......@@ -149,7 +149,6 @@ class EMAModel:
model = self.model_cls.from_config(self.model_config)
state_dict = self.state_dict()
state_dict.pop("shadow_params", None)
state_dict.pop("collected_params", None)
model.register_to_config(**state_dict)
self.copy_to(model.parameters())
......@@ -248,9 +247,35 @@ class EMAModel:
"inv_gamma": self.inv_gamma,
"power": self.power,
"shadow_params": self.shadow_params,
"collected_params": self.collected_params,
}
def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
r"""
Args:
Save the current parameters for restoring later.
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
r"""
Args:
Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
affecting the original optimization process. Store the parameters before the `copy_to()` method. After
validation (or model saving), use this to restore the former parameters.
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters. If `None`, the parameters with which this
`ExponentialMovingAverage` was initialized will be used.
"""
if self.temp_stored_params is None:
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
for c_param, param in zip(self.temp_stored_params, parameters):
param.data.copy_(c_param.data)
# Better memory-wise.
self.temp_stored_params = None
def load_state_dict(self, state_dict: dict) -> None:
r"""
Args:
......@@ -297,12 +322,3 @@ class EMAModel:
raise ValueError("shadow_params must be a list")
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
raise ValueError("shadow_params must all be Tensors")
self.collected_params = state_dict.get("collected_params", None)
if self.collected_params is not None:
if not isinstance(self.collected_params, list):
raise ValueError("collected_params must be a list")
if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
raise ValueError("collected_params must all be Tensors")
if len(self.collected_params) != len(self.shadow_params):
raise ValueError("collected_params and shadow_params must have the same length")
......@@ -212,7 +212,7 @@ class StableDiffusionPipelineSafe(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionSAGPipeline(metaclass=DummyObject):
class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
......@@ -227,7 +227,7 @@ class StableDiffusionSAGPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
class StableDiffusionSAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment