".github/vscode:/vscode.git/clone" did not exist on "caa02392d3f415d2cd466d6a21f8d2756d8b002e"
Unverified Commit d5acb411 authored by Patrick von Platen, committed by GitHub

Finalize ldm (#96)

* upload

* make checkpoint work

* finalize
parent 6cabc599
@@ -7,7 +7,15 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4"
from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, NCSNpp, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
+from .models import (
+    AutoencoderKL,
+    NCSNpp,
+    UNetConditionalModel,
+    UNetLDMModel,
+    UNetModel,
+    UNetUnconditionalModel,
+    VQModel,
+)
from .pipeline_utils import DiffusionPipeline
from .pipelines import (
DDIMPipeline,
......
@@ -17,6 +17,7 @@
# limitations under the License.
from .unet import UNetModel
from .unet_conditional import UNetConditionalModel
from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .unet_ldm import UNetLDMModel
from .unet_sde_score_estimation import NCSNpp
......
@@ -42,7 +42,7 @@ class AttentionBlockNew(nn.Module):
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        self.proj_attn = nn.Linear(channels, channels, 1)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
@@ -147,6 +147,8 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@@ -160,7 +162,7 @@ class SpatialTransformer(nn.Module):
]
)
-        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
@@ -175,6 +177,12 @@ class SpatialTransformer(nn.Module):
x = self.proj_out(x)
return x + x_in
def set_weight(self, layer):
self.norm = layer.norm
self.proj_in = layer.proj_in
self.transformer_blocks = layer.transformer_blocks
self.proj_out = layer.proj_out
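For orientation: SpatialTransformer follows the standard latent-diffusion pattern of normalizing, flattening the (b, c, h, w) feature map into a (b, h*w, inner_dim) token sequence, running transformer blocks (self-attention when `context` is None, cross-attention against text embeddings otherwise), and projecting back with a residual connection. A minimal, self-contained sketch of just the reshape round-trip (arbitrary shapes, not the class above):

```python
import torch

b, c, h, w = 2, 64, 16, 16
x_in = torch.randn(b, c, h, w)

# flatten the spatial grid into a token sequence for attention
tokens = x_in.reshape(b, c, h * w).permute(0, 2, 1)   # (b, h*w, c)

# ... transformer blocks would operate on `tokens` here ...

# fold the sequence back into a feature map and add the residual
x_out = tokens.permute(0, 2, 1).reshape(b, c, h, w) + x_in
print(x_out.shape)  # torch.Size([2, 64, 16, 16])
```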
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@@ -270,14 +278,15 @@ class FeedForward(nn.Module):
return self.net(x)
# TODO(Patrick) - this can and should be removed
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
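The `zero_module` calls this commit strips out follow a common diffusion-model convention: zero-initializing the final projection of a residual branch makes each block start as the identity, which tends to stabilize early training. What `zero_module` does, in isolation:

```python
import torch
from torch import nn

proj = nn.Linear(8, 8)
for p in proj.parameters():
    p.detach().zero_()          # same trick as zero_module above

x = torch.randn(2, 8)
assert torch.all(proj(x) == 0)  # the residual branch contributes nothing at init
```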
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
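GEGLU is the GELU-gated linear unit from Shazeer's "GLU Variants Improve Transformer" (2020): the input is projected to twice the target width, split in half, and one half gates the other through a GELU. The chunk-and-gate step, worked through:

```python
import torch
import torch.nn.functional as F
from torch import nn

dim_in, dim_out = 4, 6
proj = nn.Linear(dim_in, dim_out * 2)   # project to twice the output width

x = torch.randn(2, dim_in)
h, gate = proj(x).chunk(2, dim=-1)      # two (2, 6) halves
out = h * F.gelu(gate)                  # value half modulated by GELU of the gate half
print(out.shape)                        # torch.Size([2, 6])
```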
# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
@@ -298,17 +307,6 @@ def default(val, d):
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
# the main attention block that is used for all models
class AttentionBlock(nn.Module):
"""
@@ -348,7 +346,7 @@ class AttentionBlock(nn.Module):
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-        self.proj = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = nn.Conv1d(channels, channels, 1)
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
@@ -370,7 +368,7 @@ class AttentionBlock(nn.Module):
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
-            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+            self.proj_out = nn.Conv1d(channels, channels, 1)
self.set_weights(self)
self.is_overwritten = False
@@ -385,7 +383,7 @@ class AttentionBlock(nn.Module):
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
-        proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+        proj_out = nn.Conv1d(self.channels, self.channels, 1)
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
......
@@ -17,7 +17,7 @@ import numpy as np
import torch
from torch import nn
-from .attention import AttentionBlockNew
+from .attention import AttentionBlockNew, SpatialTransformer
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
@@ -56,6 +56,18 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResCrossAttnDownBlock2D":
return UNetResCrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResSkipDownBlock2D":
return UNetResSkipDownBlock2D(
num_layers=num_layers,
@@ -104,6 +116,18 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "UNetResCrossAttnUpBlock2D":
return UNetResCrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UNetResAttnUpBlock2D":
return UNetResAttnUpBlock2D(
num_layers=num_layers,
@@ -221,6 +245,83 @@ class UNetMidBlock2D(nn.Module):
return hidden_states
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
cross_attention_dim=1280,
**kwargs,
):
super().__init__()
self.attention_type = attention_type
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
ResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
SpatialTransformer(
in_channels,
attn_num_head_channels,
in_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
)
)
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
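The mid block therefore computes `resnet -> (attn -> resnet) * num_layers`, with the text conditioning entering only through the SpatialTransformer layers. A hedged usage sketch (module path and shapes are assumptions):

```python
import torch
from diffusers.models.unet_blocks import UNetMidBlock2DCrossAttn  # module path is an assumption

mid_block = UNetMidBlock2DCrossAttn(in_channels=512, temb_channels=1280)

hidden_states = torch.randn(1, 512, 8, 8)   # latent feature map
temb = torch.randn(1, 1280)                 # timestep embedding
context = torch.randn(1, 77, 1280)          # e.g. 77 text-token embeddings

out = mid_block(hidden_states, temb, context)
print(out.shape)                            # shape preserved: torch.Size([1, 512, 8, 8])
```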
class UNetResAttnDownBlock2D(nn.Module):
def __init__(
self,
@@ -302,6 +403,88 @@ class UNetResAttnDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResCrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
):
super().__init__()
resnets = []
attentions = []
self.attention_type = attention_type
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
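Besides its final activation, the down block returns one tensor per (resnet, attn) pair plus the downsampled output; the UNet stacks these and hands them to the matching up blocks as skip connections. A hedged usage sketch (module path and sizes are assumptions):

```python
import torch
from diffusers.models.unet_blocks import UNetResCrossAttnDownBlock2D  # module path is an assumption

block = UNetResCrossAttnDownBlock2D(
    in_channels=64, out_channels=128, temb_channels=512, num_layers=2, attn_num_head_channels=4
)
x = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 512)
context = torch.randn(1, 77, 1280)  # matches the default cross_attention_dim of 1280

x, skips = block(x, temb, context)
print(x.shape)   # torch.Size([1, 128, 16, 16]) after downsampling
print(len(skips))  # 3: two per-layer states + one downsampled state
```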
class UNetResDownBlock2D(nn.Module):
def __init__(
self,
@@ -618,6 +801,86 @@ class UNetResAttnUpBlock2D(nn.Module):
return hidden_states
class UNetResCrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_upsample=True,
):
super().__init__()
resnets = []
attentions = []
self.attention_type = attention_type
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
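On the way up, each layer pops the most recent skip tensor and concatenates it with the current activation along the channel axis, which is why the resnets above are built with `resnet_in_channels + res_skip_channels` input channels. The pop-and-concat bookkeeping in isolation (a 1x1 conv stands in for the real ResnetBlock + SpatialTransformer pair):

```python
import torch
from torch import nn

reduce = nn.Conv2d(128 + 128, 128, kernel_size=1)  # stand-in for ResnetBlock + SpatialTransformer

hidden_states = torch.randn(1, 128, 16, 16)
skips = (torch.randn(1, 128, 16, 16), torch.randn(1, 128, 16, 16))

for _ in range(2):
    res_hidden_states = skips[-1]   # pop the most recent skip tensor
    skips = skips[:-1]
    hidden_states = reduce(torch.cat([hidden_states, res_hidden_states], dim=1))

print(hidden_states.shape)          # torch.Size([1, 128, 16, 16])
```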
class UNetResUpBlock2D(nn.Module):
def __init__(
self,
@@ -765,8 +1028,6 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -864,8 +1125,6 @@ class UNetResSkipUpBlock2D(nn.Module):
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
......
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
@@ -15,6 +14,69 @@ from transformers.utils import logging
from ...pipeline_utils import DiffusionPipeline
class LatentDiffusionPipeline(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
prompt,
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
guidance_scale=1.0,
num_inference_steps=50,
):
# eta corresponds to η in the DDIM paper and should be in [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device)
self.vqvae.to(torch_device)
self.bert.to(torch_device)
# get unconditional embeddings for classifier-free guidance
if guidance_scale != 1.0:
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
torch_device
)
uncond_embeddings = self.bert(uncond_input.input_ids)
# get text embedding
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator,
).to(torch_device)
self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm.tqdm(self.scheduler.timesteps):
# 1. predict noise residual
pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
if isinstance(pred_noise_t, dict):
pred_noise_t = pred_noise_t["sample"]
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
# scale and decode image with vae
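# (dividing by 0.18215 undoes the fixed scaling the LDM authors apply to latents for roughly unit variance; an assumption based on the original latent-diffusion code)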
image = 1 / 0.18215 * image
image = self.vqvae.decode(image)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
return image
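A hedged usage sketch of the pipeline defined above (the checkpoint id is taken from the tests later in this diff; the import path mirrors the test file's import of LDMBertModel from the same module):

```python
import torch
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LatentDiffusionPipeline

ldm = LatentDiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")

generator = torch.manual_seed(0)
image = ldm(
    "A painting of a squirrel eating a burger",
    generator=generator,
    num_inference_steps=50,
    eta=0.3,
)
# `image` is a (batch, channels, height, width) tensor with values in [0, 1]
```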
################################################################################
# Code for the text transformer model
################################################################################
@@ -541,101 +603,4 @@ class LDMBertModel(LDMBertPreTrainedModel):
return_dict=return_dict,
)
sequence_output = outputs[0]
return sequence_output
class LatentDiffusionPipeline(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
prompt,
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
guidance_scale=1.0,
num_inference_steps=50,
):
# eta corresponds to η in the DDIM paper and should be in [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device)
self.vqvae.to(torch_device)
self.bert.to(torch_device)
# get unconditional embeddings for classifier-free guidance
if guidance_scale != 1.0:
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
torch_device
)
uncond_embeddings = self.bert(uncond_input.input_ids)
# get text embedding
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)
num_trained_timesteps = self.scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator,
).to(torch_device)
# See formulas (12) and (16) of the DDIM paper: https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper in detail for a full understanding
# Notation (<variable name> -> <name in paper>)
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointing to x_t"
# - pred_prev_image -> "x_t-1"
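# For reference, eq. (12) of the DDIM paper in this notation (alpha_t here denotes the
# cumulative product usually written alpha-bar):
#   pred_prev_image = sqrt(alpha_{t-1}) * pred_original_image
#                     + sqrt(1 - alpha_{t-1} - std_dev_t**2) * pred_noise_t
#                     + std_dev_t * noise
# and eq. (16) for the variance scale:
#   std_dev_t = eta * sqrt((1 - alpha_{t-1}) / (1 - alpha_t)) * sqrt(1 - alpha_t / alpha_{t-1})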
for t in tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# guidance_scale of 1 means no guidance
if guidance_scale == 1.0:
image_in = image
context = text_embedding
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
else:
# for classifier-free guidance, we need to do two forward passes
# here we concatenate the conditional and unconditional embeddings into a single batch
# to avoid doing two forward passes
image_in = torch.cat([image] * 2)
context = torch.cat([uncond_embeddings, text_embedding])
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
# 1. predict noise residual
pred_noise_t = self.unet(image_in, timesteps, context=context)
# perform guidance
if guidance_scale != 1.0:
pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
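# i.e. classifier-free guidance (Ho & Salimans): eps_hat = eps_uncond + scale * (eps_cond - eps_uncond);
# a guidance_scale > 1 pushes the prediction toward the text-conditional direction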
# 2. predict previous mean of image x_t-1
pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
# scale and decode image with vae
image = 1 / 0.18215 * image
image = self.vqvae.decode(image)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
return image
return sequence_output
\ No newline at end of file
@@ -40,14 +40,17 @@ from diffusers import (
ScoreSdeVeScheduler,
ScoreSdeVpPipeline,
ScoreSdeVpScheduler,
UNetConditionalModel,
UNetLDMModel,
UNetUnconditionalModel,
VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel
from transformers import BertTokenizer
torch.backends.cuda.matmul.allow_tf32 = False
@@ -827,7 +830,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
-class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
@property
@@ -1026,10 +1029,8 @@ class PipelineTesterMixin(unittest.TestCase):
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
-    @unittest.skip("Skipping for now as it takes too long")
def test_ldm_text2img(self):
-        model_id = "fusing/latent-diffusion-text2im-large"
-        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
@@ -1043,8 +1044,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ldm_text2img_fast(self):
-        model_id = "fusing/latent-diffusion-text2im-large"
-        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
@@ -1074,6 +1074,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_score_sde_ve_pipeline(self):
-        model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
+        model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")
torch.manual_seed(0)
if torch.cuda.is_available():
......