"src/vscode:/vscode.git/clone" did not exist on "f653ded7eda647a3f48b2d5eddff140f0ef2af4e"
Unverified Commit 5dc34713 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[SVD] support generators that are created on GPU (#6484)

* debug generator

* fix?

* fix?

* fix

* remove print.

* revert none check
parent 9df566e6
......@@ -429,15 +429,20 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
fps = fps - 1
# 4. Encode input image using VAE
image = self.image_processor.preprocess(image, height=height, width=width)
noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
image = image + noise_aug_strength * noise
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.vae.to(dtype=torch.float32)
image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
image_latents = self._encode_vae_image(
image,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
)
image_latents = image_latents.to(image_embeddings.dtype)
# cast back to fp16 if needed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment