Unverified Commit da2ce1a6 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Allow return pt x4 (#3236)

* Add all files

* update
parent e51f19ae
...@@ -697,15 +697,11 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -697,15 +697,11 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
# 10. Post-processing # 10. Post-processing
# make sure the VAE is in float32 mode, as it overflows in float16 # make sure the VAE is in float32 mode, as it overflows in float16
self.vae.to(dtype=torch.float32) self.vae.to(dtype=torch.float32)
image = self.decode_latents(latents.float())
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# 11. Convert to PIL # 11. Convert to PIL
# has_nsfw_concept = False # has_nsfw_concept = False
if output_type == "pil": if output_type == "pil":
image = self.decode_latents(latents.float())
image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
...@@ -713,9 +709,18 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -713,9 +709,18 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
# 11. Apply watermark # 11. Apply watermark
if self.watermarker is not None: if self.watermarker is not None:
image = self.watermarker.apply_watermark(image) image = self.watermarker.apply_watermark(image)
elif output_type == "pt":
latents = 1 / self.vae.config.scaling_factor * latents.float()
image = self.vae.decode(latents).sample
has_nsfw_concept = None
else: else:
image = self.decode_latents(latents.float())
has_nsfw_concept = None has_nsfw_concept = None
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict: if not return_dict:
return (image, has_nsfw_concept) return (image, has_nsfw_concept)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment