Unverified Commit d5acb411, authored by Patrick von Platen, committed by GitHub

Finalize ldm (#96)

* upload

* make checkpoint work

* finalize
parent 6cabc599
@@ -7,7 +7,15 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4"
from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, NCSNpp, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
+from .models import (
+    AutoencoderKL,
+    NCSNpp,
+    UNetConditionalModel,
+    UNetLDMModel,
+    UNetModel,
+    UNetUnconditionalModel,
+    VQModel,
+)
from .pipeline_utils import DiffusionPipeline
from .pipelines import (
DDIMPipeline,
...
@@ -17,6 +17,7 @@
# limitations under the License.
from .unet import UNetModel
+from .unet_conditional import UNetConditionalModel
from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .unet_ldm import UNetLDMModel
from .unet_sde_score_estimation import NCSNpp
...
@@ -42,7 +42,7 @@ class AttentionBlockNew(nn.Module):
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
-self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+self.proj_attn = nn.Linear(channels, channels, 1)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
@@ -147,6 +147,8 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
super().__init__()
+self.n_heads = n_heads
+self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@@ -160,7 +162,7 @@ class SpatialTransformer(nn.Module):
]
)
-self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
@@ -175,6 +177,12 @@ class SpatialTransformer(nn.Module):
x = self.proj_out(x)
return x + x_in
+def set_weight(self, layer):
+self.norm = layer.norm
+self.proj_in = layer.proj_in
+self.transformer_blocks = layer.transformer_blocks
+self.proj_out = layer.proj_out
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@@ -270,14 +278,15 @@ class FeedForward(nn.Module):
return self.net(x)
-# TODO(Patrick) - this can and should be removed
-def zero_module(module):
-"""
-Zero out the parameters of a module and return it.
-"""
-for p in module.parameters():
-p.detach().zero_()
-return module
+# feedforward
+class GEGLU(nn.Module):
+def __init__(self, dim_in, dim_out):
+super().__init__()
+self.proj = nn.Linear(dim_in, dim_out * 2)
+def forward(self, x):
+x, gate = self.proj(x).chunk(2, dim=-1)
+return x * F.gelu(gate)
# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
@@ -298,17 +307,6 @@ def default(val, d):
return d() if isfunction(d) else d
-# feedforward
-class GEGLU(nn.Module):
-def __init__(self, dim_in, dim_out):
-super().__init__()
-self.proj = nn.Linear(dim_in, dim_out * 2)
-def forward(self, x):
-x, gate = self.proj(x).chunk(2, dim=-1)
-return x * F.gelu(gate)
# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
@@ -348,7 +346,7 @@ class AttentionBlock(nn.Module):
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-self.proj = zero_module(nn.Conv1d(channels, channels, 1))
+self.proj = nn.Conv1d(channels, channels, 1)
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
@@ -370,7 +368,7 @@ class AttentionBlock(nn.Module):
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
-self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+self.proj_out = nn.Conv1d(channels, channels, 1)
self.set_weights(self)
self.is_overwritten = False
@@ -385,7 +383,7 @@ class AttentionBlock(nn.Module):
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
-proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+proj_out = nn.Conv1d(self.channels, self.channels, 1)
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
...
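Not part of the commit, just an illustration: a minimal sketch of how the SpatialTransformer patched above can be driven with cross-attention context, assuming the constructor signature (in_channels, n_heads, d_head, depth, dropout, context_dim) and forward(x, context) shown in this diff; the import path, channel sizes, and sequence length are illustrative only.

import torch
from diffusers.models.attention import SpatialTransformer  # module path assumed from this PR's layout

# 8 heads of width 40 over a 320-channel feature map, cross-attending to a 1280-dim text context
attn = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=1280)

hidden_states = torch.randn(2, 320, 32, 32)  # (batch, channels, height, width)
context = torch.randn(2, 77, 1280)           # (batch, sequence_length, context_dim)

out = attn(hidden_states, context=context)   # cross-attention against `context`
out_self = attn(hidden_states)               # with no context it falls back to self-attention
print(out.shape)                             # torch.Size([2, 320, 32, 32]); the residual keeps the spatial shape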
This diff is collapsed.
@@ -17,7 +17,7 @@ import numpy as np
import torch
from torch import nn
-from .attention import AttentionBlockNew
+from .attention import AttentionBlockNew, SpatialTransformer
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
@@ -56,6 +56,18 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResCrossAttnDownBlock2D":
return UNetResCrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResSkipDownBlock2D":
return UNetResSkipDownBlock2D(
num_layers=num_layers,
@@ -104,6 +116,18 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "UNetResCrossAttnUpBlock2D":
return UNetResCrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UNetResAttnUpBlock2D":
return UNetResAttnUpBlock2D(
num_layers=num_layers,
@@ -221,6 +245,83 @@ class UNetMidBlock2D(nn.Module):
return hidden_states
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
cross_attention_dim=1280,
**kwargs,
):
super().__init__()
self.attention_type = attention_type
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
ResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
SpatialTransformer(
in_channels,
attn_num_head_channels,
in_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
)
)
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class UNetResAttnDownBlock2D(nn.Module):
def __init__(
self,
@@ -302,6 +403,88 @@ class UNetResAttnDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResCrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
):
super().__init__()
resnets = []
attentions = []
self.attention_type = attention_type
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class UNetResDownBlock2D(nn.Module):
def __init__(
self,
@@ -618,6 +801,86 @@ class UNetResAttnUpBlock2D(nn.Module):
return hidden_states
class UNetResCrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_upsample=True,
):
super().__init__()
resnets = []
attentions = []
self.attention_type = attention_type
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetResUpBlock2D(nn.Module):
def __init__(
self,
@@ -765,8 +1028,6 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -864,8 +1125,6 @@ class UNetResSkipUpBlock2D(nn.Module):
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
...
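Not part of the commit, just an illustration: a minimal sketch of how the new UNetResCrossAttnDownBlock2D added above fits together, assuming the constructor keywords and the forward(hidden_states, temb, encoder_hidden_states) signature from this diff; the module path, channel sizes, and tensor shapes are assumptions for the example.

import torch
from diffusers.models.unet_new import UNetResCrossAttnDownBlock2D  # module path assumed

block = UNetResCrossAttnDownBlock2D(
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    num_layers=2,
    attn_num_head_channels=8,   # each SpatialTransformer uses out_channels // 8 as its head dimension
    cross_attention_dim=1280,
    add_downsample=True,
)

sample = torch.randn(1, 320, 32, 32)              # latent feature map
temb = torch.randn(1, 1280)                       # timestep embedding
encoder_hidden_states = torch.randn(1, 77, 1280)  # text-encoder output used as cross-attention context

hidden_states, output_states = block(sample, temb=temb, encoder_hidden_states=encoder_hidden_states)
# output_states collects one skip tensor per resnet/attention pair plus the downsampled output;
# the matching UNetResCrossAttnUpBlock2D later consumes these as res_hidden_states_tuple.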
from typing import Optional, Tuple, Union
-import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
@@ -15,6 +14,69 @@ from transformers.utils import logging
from ...pipeline_utils import DiffusionPipeline
class LatentDiffusionPipeline(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
prompt,
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
guidance_scale=1.0,
num_inference_steps=50,
):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device)
self.vqvae.to(torch_device)
self.bert.to(torch_device)
# get unconditional embeddings for classifier free guidence
if guidance_scale != 1.0:
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
torch_device
)
uncond_embeddings = self.bert(uncond_input.input_ids)
# get text embedding
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator,
).to(torch_device)
self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm.tqdm(self.scheduler.timesteps):
# 1. predict noise residual
pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
if isinstance(pred_noise_t, dict):
pred_noise_t = pred_noise_t["sample"]
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
# scale and decode image with vae
image = 1 / 0.18215 * image
image = self.vqvae.decode(image)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
return image
################################################################################
# Code for the text transformer model
################################################################################
@@ -542,100 +604,3 @@ class LDMBertModel(LDMBertPreTrainedModel):
)
sequence_output = outputs[0]
return sequence_output
\ No newline at end of file
-class LatentDiffusionPipeline(DiffusionPipeline):
-def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
-super().__init__()
-scheduler = scheduler.set_format("pt")
-self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-@torch.no_grad()
-def __call__(
-self,
-prompt,
-batch_size=1,
-generator=None,
-torch_device=None,
-eta=0.0,
-guidance_scale=1.0,
-num_inference_steps=50,
-):
-# eta corresponds to η in paper and should be between [0, 1]
-if torch_device is None:
-torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-self.unet.to(torch_device)
-self.vqvae.to(torch_device)
-self.bert.to(torch_device)
-# get unconditional embeddings for classifier free guidence
-if guidance_scale != 1.0:
-uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
-torch_device
-)
-uncond_embeddings = self.bert(uncond_input.input_ids)
-# get text embedding
-text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-text_embedding = self.bert(text_input.input_ids)
-num_trained_timesteps = self.scheduler.config.timesteps
-inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-image = torch.randn(
-(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
-generator=generator,
-).to(torch_device)
-# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-# Ideally, read DDIM paper in-detail understanding
-# Notation (<variable name> -> <name in paper>
-# - pred_noise_t -> e_theta(x_t, t)
-# - pred_original_image -> f_theta(x_t, t) or x_0
-# - std_dev_t -> sigma_t
-# - eta -> η
-# - pred_image_direction -> "direction pointingc to x_t"
-# - pred_prev_image -> "x_t-1"
-for t in tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-# guidance_scale of 1 means no guidance
-if guidance_scale == 1.0:
-image_in = image
-context = text_embedding
-timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-else:
-# for classifier free guidance, we need to do two forward passes
-# here we concanate embedding and unconditioned embedding in a single batch
-# to avoid doing two forward passes
-image_in = torch.cat([image] * 2)
-context = torch.cat([uncond_embeddings, text_embedding])
-timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-# 1. predict noise residual
-pred_noise_t = self.unet(image_in, timesteps, context=context)
-# perform guidance
-if guidance_scale != 1.0:
-pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
-pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
-# 2. predict previous mean of image x_t-1
-pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
-# 3. optionally sample variance
-variance = 0
-if eta > 0:
-noise = torch.randn(image.shape, generator=generator).to(image.device)
-variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
-# 4. set current image to prev_image: x_t -> x_t-1
-image = pred_prev_image + variance
-# scale and decode image with vae
-image = 1 / 0.18215 * image
-image = self.vqvae.decode(image)
-image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
-return image
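Not part of the commit, just an illustration: a short sketch of how the rewritten LatentDiffusionPipeline above is meant to be called end to end. The checkpoint id is reused from the old test code in this diff, the top-level import is assumed, and the shape/range check only reflects what __call__ returns in this version.

import torch
from diffusers import LatentDiffusionPipeline  # top-level export assumed

# checkpoint id taken from the previous test code; substitute whatever LDM text2img checkpoint you have
ldm = LatentDiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")

generator = torch.manual_seed(0)
# __call__ tokenizes the prompt, embeds it with the LDMBert text encoder, runs the UNet/scheduler loop
# with the text embedding as encoder_hidden_states, and decodes the latents with the VQ-VAE
image = ldm("A painting of a squirrel eating a burger", generator=generator, num_inference_steps=50)

# a float tensor in [0, 1] of shape (batch, 3, height, width)
print(image.shape, image.min().item(), image.max().item())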
@@ -40,14 +40,17 @@ from diffusers import (
ScoreSdeVeScheduler,
ScoreSdeVpPipeline,
ScoreSdeVpScheduler,
+UNetConditionalModel,
UNetLDMModel,
UNetUnconditionalModel,
VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel
+from transformers import BertTokenizer
torch.backends.cuda.matmul.allow_tf32 = False
@@ -827,7 +830,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
-class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
@property
@@ -1026,10 +1029,8 @@ class PipelineTesterMixin(unittest.TestCase):
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
-@unittest.skip("Skipping for now as it takes too long")
def test_ldm_text2img(self):
-model_id = "fusing/latent-diffusion-text2im-large"
-ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
@@ -1043,8 +1044,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ldm_text2img_fast(self):
-model_id = "fusing/latent-diffusion-text2im-large"
-ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
@@ -1074,6 +1074,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_score_sde_ve_pipeline(self):
model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
+model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")
torch.manual_seed(0)
if torch.cuda.is_available():
...