Unverified Commit a3cc641f authored by RogerSinghChugh's avatar RogerSinghChugh Committed by GitHub
Browse files

Refac training utils.py (#9815)



* Refac training utils.py

* quality

---------
Co-authored-by: default avatarsayakpaul <spsayakpaul@gmail.com>
parent 13e8fdec
......@@ -43,6 +43,9 @@ def set_seed(seed: int):
Args:
seed (`int`): The seed to set.
Returns:
`None`
"""
random.seed(seed)
np.random.seed(seed)
......@@ -58,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps):
"""
Computes SNR as per
https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
for the given timesteps using the provided noise scheduler.
Args:
noise_scheduler (`NoiseScheduler`):
An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
the SNR values.
timesteps (`torch.Tensor`):
A tensor of timesteps for which the SNR is computed.
Returns:
`torch.Tensor`: A tensor containing the computed SNR values for each timestep.
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment