Unverified Commit c7514490 authored by Jianqi Pan's avatar Jianqi Pan Committed by GitHub
Browse files

fix: use retrieve_latents (#6337)

parent c1e8bdf1
......@@ -50,6 +50,7 @@ from diffusers.pipelines.stable_diffusion import (
StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
......@@ -608,7 +609,7 @@ class TorchVAEEncoder(torch.nn.Module):
self.vae_encoder = model
def forward(self, x):
return self.vae_encoder.encode(x).latent_dist.sample()
return retrieve_latents(self.vae_encoder.encode(x))
class VAEEncoder(BaseModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment