Unverified commit 355bb641 authored by Jiwook Han, committed by GitHub

[doc] Fix some docstrings in `src/diffusers/training_utils.py` (#9606)



* refac: docstrings in training_utils.py

* fix: manual edits

* run make style

* add docstring at cast_training_params

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 92d2baf6
@@ -36,8 +36,9 @@ if is_torch_npu_available():
 def set_seed(seed: int):
     """
-    Args:
     Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+
+    Args:
         seed (`int`): The seed to set.
     """
     random.seed(seed)
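For context, `set_seed` is typically called once at the top of a training script, before dataloaders and models are built. A minimal sketch, assuming only what the docstring above states:

```python
from diffusers.training_utils import set_seed

# Seed `random`, `numpy`, and `torch` in one call so data shuffling,
# augmentation, and weight init are reproducible across runs.
set_seed(42)
```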
@@ -194,6 +195,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
 def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    """
+    Casts the training parameters of the model to the specified data type.
+
+    Args:
+        model: The PyTorch model whose parameters will be cast.
+        dtype: The data type to which the model parameters will be cast.
+    """
     if not isinstance(model, list):
         model = [model]
     for m in model:
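A minimal sketch of how this is used in mixed-precision training. The tiny `Linear` stands in for a real model (e.g. a UNet with LoRA layers), and the trainable/frozen split assumes the loop continuing past this hunk only upcasts parameters with `requires_grad=True`, as in the current implementation:

```python
import torch
from diffusers.training_utils import cast_training_params

model = torch.nn.Linear(4, 4).to(dtype=torch.float16)  # stand-in for a half-precision model
model.bias.requires_grad_(False)                        # pretend the bias is frozen

cast_training_params(model, dtype=torch.float32)
print(model.weight.dtype)  # torch.float32 -- trainable params are upcast for optimizer stability
print(model.bias.dtype)    # torch.float16 -- frozen params are left untouched
```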
@@ -225,7 +233,8 @@ def _set_state_dict_into_text_encoder(
 def compute_density_for_timestep_sampling(
     weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
 ):
-    """Compute the density for sampling the timesteps when doing SD3 training.
+    """
+    Compute the density for sampling the timesteps when doing SD3 training.
 
     Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
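A sketch of how the sampled density is typically consumed in an SD3-style training step. `"logit_normal"` is one of the supported schemes; the 1000 train timesteps below is an assumption for illustration, not something this diff specifies:

```python
from diffusers.training_utils import compute_density_for_timestep_sampling

# Draw one density value in [0, 1) per sample, then map it to a discrete timestep index.
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal", batch_size=4, logit_mean=0.0, logit_std=1.0
)
timestep_indices = (u * 1000).long()  # 1000 = assumed num_train_timesteps of the scheduler
```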
@@ -244,7 +253,8 @@ def compute_density_for_timestep_sampling(
 def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
-    """Computes loss weighting scheme for SD3 training.
+    """
+    Computes loss weighting scheme for SD3 training.
 
     Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
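And its companion at loss time, again as a sketch: `"sigma_sqrt"` is one supported scheme, and the random `sigmas` stand in for the per-sample noise levels a real training loop would provide:

```python
import torch
from diffusers.training_utils import compute_loss_weighting_for_sd3

sigmas = torch.rand(4, 1, 1, 1) + 0.1  # placeholder noise levels, one per sample
weighting = compute_loss_weighting_for_sd3(weighting_scheme="sigma_sqrt", sigmas=sigmas)
# `weighting` multiplies the per-element MSE before it is reduced to the final scalar loss.
```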
@@ -261,7 +271,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
 def free_memory():
-    """Runs garbage collection. Then clears the cache of the available accelerator."""
+    """
+    Runs garbage collection. Then clears the cache of the available accelerator.
+    """
     gc.collect()
     if torch.cuda.is_available():
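Usage is a two-step pattern, sketched below: drop your own references first, since `free_memory` can only reclaim memory that is no longer referenced by live Python objects:

```python
import torch
from diffusers.training_utils import free_memory

big = torch.empty(1024, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
del big        # release the Python reference first
free_memory()  # then run gc and clear the accelerator cache
```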
@@ -494,7 +506,8 @@ class EMAModel:
             self.shadow_params = [p.pin_memory() for p in self.shadow_params]
 
     def to(self, device=None, dtype=None, non_blocking=False) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+        r"""
+        Move internal buffers of the ExponentialMovingAverage to `device`.
 
         Args:
             device: like `device` argument to `torch.Tensor.to`
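A small sketch of the method above; note it moves only the EMA's internal shadow copies, not any live model:

```python
import torch
from diffusers.training_utils import EMAModel

ema = EMAModel(torch.nn.Linear(4, 4).parameters())
ema.to(device="cpu")  # keep the shadow params off the accelerator to save memory
```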
@@ -528,23 +541,25 @@ class EMAModel:
     def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
-        Args:
-        Save the current parameters for restoring later.
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                temporarily stored.
+        Saves the current parameters for restoring later.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
         """
         self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
 
     def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
-        Args:
-        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
-        affecting the original optimization process. Store the parameters before the `copy_to()` method. After
+        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
+        without affecting the original optimization process. Store the parameters before the `copy_to()` method. After
         validation (or model saving), use this to restore the former parameters.
+
+        Args:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 updated with the stored parameters. If `None`, the parameters with which this
                 `ExponentialMovingAverage` was initialized will be used.
         """
         if self.temp_stored_params is None:
             raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
         if self.foreach:
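The two methods are designed to be used together around validation, as the docstring describes. A minimal sketch; the `Linear` is a stand-in for a real model, and the commented step markers are assumptions about a typical training loop:

```python
import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 4)
ema = EMAModel(model.parameters())
# ... during training, each optimizer step is followed by ema.step(model.parameters()) ...

ema.store(model.parameters())    # stash the raw training weights on CPU
ema.copy_to(model.parameters())  # swap the EMA weights into the model
# ... run validation / save the EMA model here ...
ema.restore(model.parameters())  # put the raw training weights back
```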
@@ -560,9 +575,10 @@ class EMAModel:
     def load_state_dict(self, state_dict: dict) -> None:
         r"""
-        Args:
         Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
         ema state dict.
+
+        Args:
             state_dict (dict): EMA state. Should be an object returned
                 from a call to :meth:`state_dict`.
         """