Unverified Commit 0db766ba authored by Yunxuan Xiao, committed by GitHub

[DDPMScheduler] Load `alpha_cumprod` to device to avoid redundant data movement. (#6704)



* load cumprod tensor to device
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>

* fixing ci
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>

* make fix-copies
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>

---------
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
parent 8e946635
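
Before this change, every `add_noise` / `get_velocity` call built a fresh device-and-dtype copy of `self.alphas_cumprod` and discarded it, so each call paid a CPU-to-GPU transfer of the full schedule tensor. The diff below instead writes the device move back to `self.alphas_cumprod`; once the tensor is on the right device, `.to(device=...)` returns the same tensor, so later calls cost nothing. Only the dtype cast remains per-call. A minimal, self-contained sketch of the pattern (`ToyScheduler` is hypothetical, for illustration only; its `add_noise` body mirrors the patched scheduler code):

import torch


class ToyScheduler:
    def __init__(self, num_train_timesteps: int = 1000):
        betas = torch.linspace(1e-4, 2e-2, num_train_timesteps)
        # Built in float32 on CPU at construction time, like the real schedulers.
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def add_noise(self, original_samples, noise, timesteps):
        # Persist the device move: after the first call this .to() returns
        # the same tensor, so no further CPU-to-GPU copies happen.
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        # Keep the dtype cast local so the stored tensor stays float32.
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        # Broadcast the per-timestep scalars across the sample dimensions.
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
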
src/diffusers/schedulers/scheduling_ddim.py
@@ -477,7 +477,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -498,7 +501,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
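
Note the asymmetry the diff introduces: the device move is written back to `self.alphas_cumprod`, while the dtype cast stays local. Caching the cast too would permanently downcast the stored float32 schedule the first time a half-precision sample came through. A quick illustration of why (this snippet is an editorial aside, not part of the commit):

import torch

alphas_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 2e-2, 1000), dim=0)
# Round-tripping through float16 loses precision, which is why the cast
# result is never stored back on the scheduler.
half = alphas_cumprod.to(dtype=torch.float16)
print((alphas_cumprod - half.float()).abs().max())  # > 0: irreversible rounding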

src/diffusers/schedulers/scheduling_ddim_parallel.py
@@ -602,7 +602,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -623,7 +626,8 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_ddpm.py
@@ -503,7 +503,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -523,7 +526,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -594,7 +594,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -615,7 +618,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_lcm.py
@@ -575,7 +575,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -596,7 +599,8 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_pndm.py
@@ -455,7 +455,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_sasolver.py
@@ -1069,7 +1069,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5

src/diffusers/schedulers/scheduling_unclip.py
@@ -332,7 +332,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
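
A hedged usage sketch of the effect (assumes a CUDA device and the `DDPMScheduler` API as of this commit): only the first `add_noise` call pays the host-to-device transfer; subsequent calls index the cached GPU tensor directly.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
samples = torch.randn(4, 3, 64, 64, device="cuda")
noise = torch.randn_like(samples)
timesteps = torch.randint(0, 1000, (4,), device="cuda")

noisy = scheduler.add_noise(samples, noise, timesteps)  # moves alphas_cumprod once
assert scheduler.alphas_cumprod.device.type == "cuda"   # now cached on the GPU
noisy = scheduler.add_noise(samples, noise, timesteps)  # no CPU-to-GPU copy here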