Unverified Commit e4546fd5 authored by Aryan, committed by GitHub

[docs] Add missing copied from statements in TCD Scheduler (#7360)



* add missing copied from statements in tcd scheduler

* update docstring

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent d44e31ae
@@ -307,6 +307,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
+
         Returns:
             `torch.FloatTensor`:
                 A scaled input sample.
@@ -364,7 +365,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
         device: Union[str, torch.device] = None,
         original_inference_steps: Optional[int] = None,
         timesteps: Optional[List[int]] = None,
-        strength: int = 1.0,
+        strength: float = 1.0,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -384,6 +385,8 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                 timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
                 schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
+            strength (`float`, *optional*, defaults to 1.0):
+                Used to determine the number of timesteps used for inference when using img2img, inpaint, etc.
         """
         # 0. Check inputs
         if num_inference_steps is None and timesteps is None:
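
The `strength` argument documented above follows the LCM-style schedule truncation: only the first `original_inference_steps * strength` training timesteps are kept, so img2img and inpaint pipelines start denoising partway into the noise range instead of from pure noise. A minimal sketch of the effect, assuming only the `TCDScheduler` API shown in this diff and its default config (step counts are illustrative):

```python
from diffusers import TCDScheduler

scheduler = TCDScheduler()

# Full strength: the 8 timesteps span the whole training range.
scheduler.set_timesteps(num_inference_steps=8, strength=1.0)
print(scheduler.timesteps)

# strength=0.5 keeps only the lower half of the noise range, so an
# img2img pipeline starts from a partially noised input.
scheduler.set_timesteps(num_inference_steps=8, strength=0.5)
print(scheduler.timesteps)
```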
@@ -624,6 +627,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
 
         return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -631,7 +635,10 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -647,11 +654,13 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
     def get_velocity(
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
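
The `add_noise` and `get_velocity` changes above mirror the DDPM implementation they are now annotated as copies of: the device move is persisted on `self.alphas_cumprod`, so the CPU-to-GPU copy happens only once across repeated calls, while the dtype cast happens per call on a local copy, keeping the cached table in full precision. A standalone sketch of the pattern, using a hypothetical `ToyScheduler` (not part of diffusers):

```python
import torch

class ToyScheduler:
    def __init__(self, num_train_timesteps: int = 1000):
        betas = torch.linspace(1e-4, 0.02, num_train_timesteps)
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # lives on CPU initially

    def add_noise(self, original_samples, noise, timesteps):
        # Persist the device move on self so subsequent calls reuse the cached copy.
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        # Cast dtype on a local copy only, so the cached table stays full precision.
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        # Broadcast the per-timestep scalars up to the sample's rank.
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
```

Casting only the local copy matters for fp16 samples: repeatedly overwriting `self.alphas_cumprod` with a half-precision version would compound rounding error across calls.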
@@ -670,6 +679,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
     def __len__(self):
         return self.config.num_train_timesteps
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
    def previous_timestep(self, timestep):
         if self.custom_timesteps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
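
For reference, `previous_timestep` (now pinned to the DDPM implementation by the annotation above) walks backward through either a custom timestep list or an evenly spaced schedule. A standalone sketch of that logic, written as a free function for illustration (the real code is a method on the scheduler):

```python
import torch

def previous_timestep(timesteps, timestep, num_train_timesteps,
                      num_inference_steps=None, custom_timesteps=False):
    if custom_timesteps:
        # Locate the current timestep and return the next entry in the
        # (descending) schedule, or -1 once the schedule is exhausted.
        index = (timesteps == timestep).nonzero(as_tuple=True)[0][0]
        if index == timesteps.shape[0] - 1:
            return torch.tensor(-1)
        return timesteps[index + 1]
    # Evenly spaced schedule: step back by the training/inference skip size.
    steps = num_inference_steps if num_inference_steps else num_train_timesteps
    return timestep - num_train_timesteps // steps
```

The `# Copied from` annotations themselves are machine-checked in the diffusers repo: the copy-consistency tooling (run via `make fix-copies`) rewrites an annotated body to match its source, which is why the method bodies here were updated together with the new comments.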