Unverified commit 0e82fb19, authored by Patrick von Platen, committed by GitHub

Torch compile graph fix (#3286)

* fix more

* Fix more

* fix more

* Apply suggestions from code review

* fix

* make style

* make fix-copies

* fix

* make sure torch compile

* Clean

* fix test
parent 709cf554
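
This commit removes TorchDynamo graph breaks so the UNet can be captured in a single `torch.compile` graph. A minimal usage sketch of what the fix enables (the checkpoint id and fp16 settings are illustrative, not part of this commit):

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any Stable Diffusion checkpoint works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# With the graph breaks below fixed, the UNet traces end to end.
pipe.unet = torch.compile(pipe.unet)

image = pipe("a photo of an astronaut riding a horse").images[0]
```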
@@ -18,6 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import maybe_allow_in_graph
 from ..utils.import_utils import is_xformers_available
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
@@ -193,6 +194,7 @@ class AttentionBlock(nn.Module):
         return hidden_states
 
 
+@maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
     r"""
     A basic Transformer block.
...
@@ -17,7 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import deprecate, logging
+from ..utils import deprecate, logging, maybe_allow_in_graph
 from ..utils.import_utils import is_xformers_available
@@ -31,6 +31,7 @@ else:
     xformers = None
 
 
+@maybe_allow_in_graph
 class Attention(nn.Module):
     r"""
     A cross attention layer.
...
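
`maybe_allow_in_graph` marks `BasicTransformerBlock` and `Attention` as safe for TorchDynamo to inline rather than treating the module boundary as a graph break. A plausible sketch of such a decorator, assuming the real definition in `..utils` gates on PyTorch 2.x availability:

```python
import torch

def maybe_allow_in_graph(cls):
    # PyTorch 2.x: let TorchDynamo inline calls into this class.
    if hasattr(torch, "_dynamo"):
        return torch._dynamo.allow_in_graph(cls)
    # Older PyTorch: no-op.
    return cls
```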
@@ -77,8 +77,14 @@ def get_parameter_device(parameter: torch.nn.Module):
 
 def get_parameter_dtype(parameter: torch.nn.Module):
     try:
-        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
-        return next(parameters_and_buffers).dtype
+        params = tuple(parameter.parameters())
+        if len(params) > 0:
+            return params[0].dtype
+
+        buffers = tuple(parameter.buffers())
+        if len(buffers) > 0:
+            return buffers[0].dtype
+
     except StopIteration:
         # For torch.nn.DataParallel compatibility in PyTorch 1.5
...
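
Consuming an `itertools.chain` iterator with `next(...)` is hard for TorchDynamo to trace and can force a graph break; materializing parameters and buffers into tuples and indexing them traces cleanly. A standalone sketch of the rewritten helper (the original's `StopIteration` fallback for `DataParallel` is omitted here):

```python
import torch
from torch import nn

def get_parameter_dtype(module: nn.Module) -> torch.dtype:
    # Plain tuples and indexing instead of iterator consumption.
    params = tuple(module.parameters())
    if len(params) > 0:
        return params[0].dtype
    buffers = tuple(module.buffers())
    return buffers[0].dtype

print(get_parameter_dtype(nn.Linear(2, 2)))  # torch.float32
```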
@@ -560,7 +560,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
@@ -868,15 +869,16 @@ class CrossAttnDownBlock2D(nn.Module):
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         return hidden_states, output_states
@@ -949,13 +951,13 @@ class DownBlock2D(nn.Module):
             else:
                 hidden_states = resnet(hidden_states, temb)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         return hidden_states, output_states
@@ -1342,13 +1344,13 @@ class ResnetDownsampleBlock2D(nn.Module):
             else:
                 hidden_states = resnet(hidden_states, temb)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states, temb)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         return hidden_states, output_states
@@ -1466,13 +1468,13 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                 **cross_attention_kwargs,
             )
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states, temb)
 
-            output_states += (hidden_states,)
+            output_states = output_states + (hidden_states,)
 
         return hidden_states, output_states
@@ -1859,7 +1861,8 @@ class CrossAttnUpBlock2D(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
...
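
For tuples, `output_states += (hidden_states,)` is already equivalent to the explicit rebind, since tuples are immutable and `+=` falls back to `__add__`; spelling it out avoids the in-place-add bytecode, which is presumably why this diff prefers the explicit form for tracing. A quick check:

```python
# Tuples are immutable: += rebinds the name rather than mutating.
states = ()
states += (1,)           # same result as the explicit form below
states = states + (2,)   # explicit rebind, as used in the diff above

assert states == (1, 2)
```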
@@ -682,7 +682,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
@@ -697,7 +697,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             # there might be better ways to encapsulate this.
             class_labels = class_labels.to(dtype=sample.dtype)
 
-            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
 
             if self.config.class_embeddings_concat:
                 emb = torch.cat([emb, class_emb], dim=-1)
...
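
`self.dtype` on a `ModelMixin` walks the module's parameters (via the `get_parameter_dtype` helper above) on every call, pulling Python-side iteration into the traced region; `sample.dtype` reads a property of a tensor already in the graph. A toy illustration of the pattern (module and shapes are hypothetical):

```python
import torch
from torch import nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, sample, t_emb):
        # Take the dtype from a graph input, not from parameter iteration.
        t_emb = t_emb.to(dtype=sample.dtype)
        return self.proj(sample + t_emb)

block = torch.compile(TinyBlock())
out = block(torch.randn(2, 4), torch.randn(2, 4))
```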
@@ -437,7 +437,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -683,7 +683,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -691,7 +692,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
...
@@ -793,7 +793,8 @@ class IFPipeline(DiffusionPipeline):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -805,8 +806,8 @@ class IFPipeline(DiffusionPipeline):
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -829,7 +830,7 @@ class IFPipeline(DiffusionPipeline):
         # 11. Apply watermark
         if self.watermarker is not None:
-            self.watermarker.apply_watermark(image, self.unet.config.sample_size)
+            image = self.watermarker.apply_watermark(image, self.unet.config.sample_size)
 
         elif output_type == "pt":
             nsfw_detected = None
             watermark_detected = None
...
@@ -256,7 +256,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -134,7 +134,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -516,7 +516,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -440,7 +440,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -686,7 +686,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -694,7 +695,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
...
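
Every `ModelOutput` access in the hot loop (`.sample`, `.prev_sample`) becomes `return_dict=False` plus `[0]`, so no output dataclass is constructed inside the traced region. Condensed, the denoising loop now reads roughly as follows (setup and callback handling omitted):

```python
for i, t in enumerate(timesteps):
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

    # Tuple return instead of a ModelOutput dataclass.
    noise_pred = self.unet(
        latent_model_input,
        t,
        encoder_hidden_states=prompt_embeds,
        cross_attention_kwargs=cross_attention_kwargs,
        return_dict=False,
    )[0]

    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # compute the previous noisy sample x_t -> x_t-1
    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
```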
@@ -454,7 +454,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -496,7 +496,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -326,7 +326,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -648,7 +648,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -195,7 +195,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -525,7 +525,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -446,7 +446,7 @@ class StableDiffusionInpaintPipelineLegacy(
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -656,7 +656,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...
@@ -358,7 +358,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
...