Unverified Commit c7514490 authored by Jianqi Pan's avatar Jianqi Pan Committed by GitHub
Browse files

fix: use retrieve_latents (#6337)

parent c1e8bdf1
...@@ -50,6 +50,7 @@ from diffusers.pipelines.stable_diffusion import ( ...@@ -50,6 +50,7 @@ from diffusers.pipelines.stable_diffusion import (
StableDiffusionPipelineOutput, StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker, StableDiffusionSafetyChecker,
) )
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents
from diffusers.schedulers import DDIMScheduler from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging from diffusers.utils import logging
...@@ -608,7 +609,7 @@ class TorchVAEEncoder(torch.nn.Module): ...@@ -608,7 +609,7 @@ class TorchVAEEncoder(torch.nn.Module):
self.vae_encoder = model self.vae_encoder = model
def forward(self, x): def forward(self, x):
return self.vae_encoder.encode(x).latent_dist.sample() return retrieve_latents(self.vae_encoder.encode(x))
class VAEEncoder(BaseModel): class VAEEncoder(BaseModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment