"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "db718021717b76def2725584303bdb5b221e2677"
Unverified commit abbf3c1a authored by Patrick von Platen, committed by GitHub

Allow fp16 attn for x4 upscaler (#3239)

* Add all files

* update

* Make sure the VAE is memory efficient for PyTorch 1

* make style
parent da2ce1a6
@@ -212,6 +212,7 @@ class Decoder(nn.Module):
         sample = z
         sample = self.conv_in(sample)
+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
 
         if self.training and self.gradient_checkpointing:
 
             def create_custom_forward(module):
@@ -222,6 +223,7 @@ class Decoder(nn.Module):
             # middle
             sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
@@ -229,6 +231,7 @@ class Decoder(nn.Module):
         else:
             # middle
             sample = self.mid_block(sample)
+            sample = sample.to(upscale_dtype)
 
             # up
             for up_block in self.up_blocks:
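The decoder hunks above are the heart of the change: conv_in and the attention-bearing mid_block may now run in a different precision than the up blocks, so the activation has to be cast to the up blocks' dtype at the hand-off point, in both the gradient-checkpointing path and the plain path. Below is a minimal, self-contained sketch of that pattern; TinyDecoder and its layer sizes are invented for illustration and are not the diffusers Decoder.

import torch
import torch.nn as nn

# Sketch of the mixed-dtype hand-off: the front of the network stays in
# float32, the tail runs in half precision, and the activation is cast at
# the boundary between the two.
class TinyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(4, 8, 3, padding=1)    # stays float32
        self.mid_block = nn.Conv2d(8, 8, 3, padding=1)  # stays float32
        self.up_blocks = nn.ModuleList([nn.Conv2d(8, 8, 3, padding=1)])

    def forward(self, z):
        sample = self.conv_in(z)
        # whatever dtype the up blocks were cast to is the dtype to hand off in
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        sample = self.mid_block(sample)
        sample = sample.to(upscale_dtype)  # float32 -> half-precision boundary
        for up_block in self.up_blocks:
            sample = up_block(sample)
        return sample

# float16 is poorly supported for CPU ops; bfloat16 keeps the sketch runnable
# without a GPU.
half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

decoder = TinyDecoder().to(device)
decoder.up_blocks.to(half)  # only the tail is half precision
out = decoder(torch.randn(1, 4, 16, 16, device=device))
print(out.dtype)  # the half dtype chosen above

Reading the dtype off the up blocks' parameters, rather than hard-coding float16, keeps the forward pass correct whether or not the caller actually cast part of the model.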
@@ -18,6 +18,7 @@ from typing import Any, Callable, List, Optional, Union
 import numpy as np
 import PIL
 import torch
+import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...loaders import TextualInversionLoaderMixin
@@ -698,10 +699,22 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
 
+        # TODO(Patrick, William) - clean up when attention is refactored
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if not use_torch_2_0_attn and not use_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()
+
         # 11. Convert to PIL
-        # has_nsfw_concept = False
         if output_type == "pil":
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
             image = self.numpy_to_pil(image)
@@ -710,11 +723,11 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
             if self.watermarker is not None:
                 image = self.watermarker.apply_watermark(image)
         elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents.float()
+            latents = 1 / self.vae.config.scaling_factor * latents
             image = self.vae.decode(latents).sample
             has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
             has_nsfw_concept = None
 
         # Offload last model to CPU
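On the pipeline side, the three decode call sites stop force-upcasting with latents.float(), and the new block chooses per attention backend: either the latents stay in half precision and post_quant_conv, conv_in, and mid_block are cast down to match, or the latents are upcast to float32 for the fully fp32 VAE. The sketch below shows the usage this enables; it is not part of the diff and assumes a CUDA GPU, a diffusers version containing this commit, and the stabilityai/stable-diffusion-x4-upscaler checkpoint from the Hub.

import torch
from PIL import Image
from diffusers import StableDiffusionUpscalePipeline

# Load the x4 upscaler pipeline with fp16 weights.
pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
# On PyTorch 1.x, xformers provides the memory-efficient attention path:
# pipe.enable_xformers_memory_efficient_attention()

low_res = Image.new("RGB", (128, 128), "gray")  # stand-in for a real low-res image
upscaled = pipe(prompt="a white cat", image=low_res).images[0]
upscaled.save("upscaled.png")  # 4x larger output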