Unverified Commit 5d970a4a authored by hlky's avatar hlky Committed by GitHub
Browse files

WanI2V encode_image (#11164)

* WanI2V encode_image
parent de6a88c2
......@@ -220,8 +220,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
return prompt_embeds
def encode_image(self, image: PipelineImageInput):
image = self.image_processor(images=image, return_tensors="pt").to(self.device)
def encode_image(
self,
image: PipelineImageInput,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
image = self.image_processor(images=image, return_tensors="pt").to(device)
image_embeds = self.image_encoder(**image, output_hidden_states=True)
return image_embeds.hidden_states[-2]
......@@ -587,7 +592,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
image_embeds = self.encode_image(image)
image_embeds = self.encode_image(image, device)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment