Unverified Commit 355bb641 authored by Jiwook Han, committed by GitHub

[doc] Fix some docstrings in `src/diffusers/training_utils.py` (#9606)



* refac: docstrings in training_utils.py

* fix: manual edits

* run make style

* add docstring at cast_training_params

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 92d2baf6
@@ -36,8 +36,9 @@ if is_torch_npu_available():
 def set_seed(seed: int):
     """
-    Args:
     Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+
+    Args:
         seed (`int`): The seed to set.
     """
     random.seed(seed)
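For reference, a minimal usage sketch of the function documented above (import path per this module, `diffusers.training_utils`; the seed value is arbitrary):

```python
import torch
from diffusers.training_utils import set_seed

# Seeds Python's `random`, numpy, and torch in one call so repeated runs of a
# training script see the same sample order, initialization, and noise draws.
set_seed(42)

print(torch.rand(2))  # identical output across runs with the same seed
```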
@@ -194,6 +195,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
 def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    """
+    Casts the training parameters of the model to the specified data type.
+
+    Args:
+        model: The PyTorch model whose parameters will be cast.
+        dtype: The data type to which the model parameters will be cast.
+    """
     if not isinstance(model, list):
         model = [model]
     for m in model:
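A hedged sketch of the mixed-precision pattern this helper serves, assuming (as the truncated body above suggests) that only parameters with `requires_grad=True` are cast; the tiny `Linear` stands in for a real model:

```python
import torch
from diffusers.training_utils import cast_training_params

# Stand-in for a model whose base weights were loaded in fp16.
model = torch.nn.Linear(8, 8).half()
model.bias.requires_grad_(False)  # pretend the bias is frozen

# Upcast only the trainable parameters to fp32 for optimizer stability;
# frozen weights keep their storage dtype.
cast_training_params(model, dtype=torch.float32)
print(model.weight.dtype, model.bias.dtype)  # torch.float32 torch.float16
```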
@@ -225,7 +233,8 @@ def _set_state_dict_into_text_encoder(
 def compute_density_for_timestep_sampling(
     weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
 ):
-    """Compute the density for sampling the timesteps when doing SD3 training.
+    """
+    Compute the density for sampling the timesteps when doing SD3 training.
 
     Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
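For context, a sketch of how the SD3-style training scripts consume the returned densities, mapping values in [0, 1) to discrete timesteps; the `num_train_timesteps` value is an illustrative assumption:

```python
import torch
from diffusers.training_utils import compute_density_for_timestep_sampling

num_train_timesteps = 1000  # illustrative; read this from your scheduler config

# Sample u from a logit-normal, one of the schemes the function supports.
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",
    batch_size=4,
    logit_mean=0.0,
    logit_std=1.0,
)
indices = (u * num_train_timesteps).long()  # per-sample training timesteps
```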
@@ -244,7 +253,8 @@ def compute_density_for_timestep_sampling(
 def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
-    """Computes loss weighting scheme for SD3 training.
+    """
+    Computes loss weighting scheme for SD3 training.
 
     Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
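Likewise, a hedged sketch of applying the returned weighting to an elementwise loss before reduction; the shapes and the `"sigma_sqrt"` scheme choice are illustrative:

```python
import torch
from diffusers.training_utils import compute_loss_weighting_for_sd3

# Per-sample noise levels, broadcastable against the prediction tensor; a
# real script takes these from the scheduler.
sigmas = torch.rand(4, 1, 1, 1) + 0.1
weighting = compute_loss_weighting_for_sd3(weighting_scheme="sigma_sqrt", sigmas=sigmas)

model_pred = torch.randn(4, 3, 8, 8)
target = torch.randn(4, 3, 8, 8)
# Weight the squared error per element, then reduce per sample and batch.
loss = (weighting * (model_pred - target) ** 2).reshape(4, -1).mean(dim=1).mean()
```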
@@ -261,7 +271,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
 def free_memory():
-    """Runs garbage collection. Then clears the cache of the available accelerator."""
+    """
+    Runs garbage collection. Then clears the cache of the available accelerator.
+    """
     gc.collect()
     if torch.cuda.is_available():
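A small sketch of when one would call it, e.g. between validation and the next training phase (the tensor is only there to create something to free):

```python
import torch
from diffusers.training_utils import free_memory

device = "cuda" if torch.cuda.is_available() else "cpu"
activations = torch.randn(1024, 1024, device=device)  # large intermediate
del activations
free_memory()  # garbage collection plus the accelerator's cache clear
```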
@@ -494,7 +506,8 @@ class EMAModel:
         self.shadow_params = [p.pin_memory() for p in self.shadow_params]
 
     def to(self, device=None, dtype=None, non_blocking=False) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+        r"""
+        Move internal buffers of the ExponentialMovingAverage to `device`.
 
         Args:
             device: like `device` argument to `torch.Tensor.to`
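Assuming an `EMAModel` built from a model's parameters (the constructor call and decay value follow common diffusers usage, not this diff), moving the shadow parameters looks like:

```python
import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(8, 8)
ema = EMAModel(model.parameters(), decay=0.9999)

# Keep the EMA buffers on the training device; the keyword names match the
# signature documented above (`device`, `dtype`, `non_blocking`).
if torch.cuda.is_available():
    ema.to(device="cuda", non_blocking=True)
```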
@@ -528,23 +541,25 @@ class EMAModel:
     def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
+        Saves the current parameters for restoring later.
+
         Args:
-        Save the current parameters for restoring later.
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                temporarily stored.
+            parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
         """
         self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
 
     def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
-        Args:
-        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
-        affecting the original optimization process. Store the parameters before the `copy_to()` method. After
+        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
+        without affecting the original optimization process. Store the parameters before the `copy_to()` method. After
         validation (or model saving), use this to restore the former parameters.
+
+        Args:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 updated with the stored parameters. If `None`, the parameters with which this
                 `ExponentialMovingAverage` was initialized will be used.
         """
         if self.temp_stored_params is None:
             raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
         if self.foreach:
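The store/restore pair exists for the evaluation pattern the docstrings describe; a sketch under the same constructor assumption as above:

```python
import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(8, 8)
ema = EMAModel(model.parameters(), decay=0.999)

ema.step(model.parameters())  # after each optimizer step during training

# Validate with EMA weights without disturbing the optimization state:
ema.store(model.parameters())    # stash the live parameters
ema.copy_to(model.parameters())  # swap in the EMA parameters
# ... run validation or save the EMA model here ...
ema.restore(model.parameters())  # put the live parameters back
```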
@@ -560,9 +575,10 @@ class EMAModel:
     def load_state_dict(self, state_dict: dict) -> None:
         r"""
-        Args:
         Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
         ema state dict.
+
+        Args:
             state_dict (dict): EMA state. Should be an object returned
                 from a call to :meth:`state_dict`.
         """