Commit 8fdecfab authored by patil-suraj's avatar patil-suraj
Browse files

fix noise device

parent cdb3c493
...@@ -943,7 +943,7 @@ class LatentDiffusion(DiffusionPipeline): ...@@ -943,7 +943,7 @@ class LatentDiffusion(DiffusionPipeline):
# 3. optionally sample variance # 3. optionally sample variance
variance = 0 variance = 0
if eta > 0: if eta > 0:
noise = torch.randn(image.shape, generator=generator, device=image.device) noise = torch.randn(image.shape, generator=generator)to(image.device)
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1 # 4. set current image to prev_image: x_t -> x_t-1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment