Unverified commit 9f28f1ab authored by Vedat Baday, committed by GitHub


feat(training-utils): support device and dtype params in compute_density_for_timestep_sampling (#10699)

* feat(training-utils): support device and dtype params in compute_density_for_timestep_sampling

* chore: update type hint

* refactor: use union for type hint

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 5d2d2398
@@ -248,7 +248,13 @@ def _set_state_dict_into_text_encoder(
 def compute_density_for_timestep_sampling(
-    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+    weighting_scheme: str,
+    batch_size: int,
+    logit_mean: float = None,
+    logit_std: float = None,
+    mode_scale: float = None,
+    device: Union[torch.device, str] = "cpu",
+    generator: Optional[torch.Generator] = None,
 ):
     """
     Compute the density for sampling the timesteps when doing SD3 training.
@@ -258,14 +264,13 @@ compute_density_for_timestep_sampling(
     SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
     """
     if weighting_scheme == "logit_normal":
-        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
         u = torch.nn.functional.sigmoid(u)
     elif weighting_scheme == "mode":
-        u = torch.rand(size=(batch_size,), device="cpu")
+        u = torch.rand(size=(batch_size,), device=device, generator=generator)
         u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
     else:
-        u = torch.rand(size=(batch_size,), device="cpu")
+        u = torch.rand(size=(batch_size,), device=device, generator=generator)
     return u
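
Not part of the commit, but as an illustration of what the new parameters enable: a minimal usage sketch, assuming a CUDA-capable machine, the "logit_normal" scheme, and an arbitrary batch size and seed.

import torch

from diffusers.training_utils import compute_density_for_timestep_sampling

# Sample the timestep density directly on the accelerator and make it reproducible.
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.Generator(device=device).manual_seed(0)

u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",
    batch_size=4,
    logit_mean=0.0,
    logit_std=1.0,
    device=device,
    generator=generator,
)
print(u.shape, u.device)  # torch.Size([4]) on the chosen device

Before this change the function always sampled on "cpu" with the global RNG; passing device keeps the samples on the training device, and generator decouples the sampling from the global seed for reproducibility.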