Unverified commit 6c4c6a04 authored by Fazzie-Maqianli, committed by GitHub

Merge pull request #2120 from Fazziekey/example/stablediffusion-v2

[example] support stable diffusion v2
parents 5efda697 cea4292a
import torch
import numpy as np


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    return x[(...,) + (None,) * dims_to_append]


def norm_thresholding(x0, value):
    # clamp the per-sample RMS norm to at least `value`, then rescale
    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
    return x0 * (value / s)


def spatial_norm_thresholding(x0, value):
    # b c h w: clamp the per-position channel RMS to at least `value`
    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
    return x0 * (value / s)
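For reviewers unfamiliar with the k-diffusion helpers above: append_dims right-pads a tensor's shape with singleton dimensions so a per-sample scalar can broadcast against an image batch, and norm_thresholding uses it to rescale any sample whose RMS norm exceeds a floor. A minimal sketch of how the two compose, assuming the functions above are in scope (shapes and the threshold are made-up illustration values):

import torch

x0 = torch.randn(2, 4, 64, 64)                # b c h w latents (illustrative shape)
rms = x0.pow(2).flatten(1).mean(1).sqrt()     # per-sample RMS, shape (2,)
rms = append_dims(rms, x0.ndim)               # shape (2, 1, 1, 1), broadcasts against x0

x0_thresh = norm_thresholding(x0, value=1.0)  # samples with RMS <= 1.0 pass through unchanged;
                                              # larger ones are rescaled down to RMS == 1.0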
import torch
import torch.nn as nn
import numpy as np
from functools import partial

from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ldm.util import default


class AbstractLowScaleModel(nn.Module):
    # for concatenating a downsampled image to the latent representation
    def __init__(self, noise_schedule_config=None):
        super(AbstractLowScaleModel, self).__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def forward(self, x):
        return x, None

    def decode(self, x):
        return x


class SimpleImageConcat(AbstractLowScaleModel):
    # no noise level conditioning
    def __init__(self):
        super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        # fix to constant noise level
        return x, torch.zeros(x.shape[0], device=x.device).long()


class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        z = self.q_sample(x, noise_level)
        return z, noise_level
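ImageConcatWithNoiseAugmentation is the noise-augmentation conditioner pattern used by the SD v2 upscaler: it diffuses the low-resolution conditioning image to a random (or given) timestep via q_sample and returns the chosen level so the UNet can be conditioned on it as well. A hedged usage sketch, with an illustrative config (register_schedule's defaults use these same schedule values):

import torch

noise_aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"beta_schedule": "linear", "timesteps": 1000,
                           "linear_start": 1e-4, "linear_end": 2e-2},
    max_noise_level=350,                 # illustrative cap; must be <= timesteps
)

low_res = torch.randn(2, 3, 64, 64)      # stand-in for a low-res conditioning image
z, noise_level = noise_aug(low_res)      # z == q_sample(low_res, noise_level)
assert z.shape == low_res.shape and noise_level.shape == (2,)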
@@ -122,7 +122,9 @@ class CheckpointFunction(torch.autograd.Function):
         ctx.run_function = run_function
         ctx.input_tensors = list(args[:length])
         ctx.input_params = list(args[length:])
+        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+                                   "dtype": torch.get_autocast_gpu_dtype(),
+                                   "cache_enabled": torch.is_autocast_cache_enabled()}
         with torch.no_grad():
             output_tensors = ctx.run_function(*ctx.input_tensors)
         return output_tensors
@@ -130,7 +132,8 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *output_grads):
         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        with torch.enable_grad():
+        with torch.enable_grad(), \
+                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
             # Fixes a bug where the first op in run_function modifies the
             # Tensor storage in place, which is not allowed for detach()'d
             # Tensors.
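Taken together, these two hunks make gradient checkpointing autocast-aware: the forward pass records the ambient GPU autocast state, and the backward pass re-enters it before recomputing activations, so the recomputed tensors carry the same dtypes as the originals. A self-contained sketch of the same pattern, independent of the ldm code (class and variable names are mine):

import torch

class AutocastCheckpoint(torch.autograd.Function):
    # recompute fn in backward under the same autocast state as forward

    @staticmethod
    def forward(ctx, fn, x):
        ctx.fn = fn
        ctx.save_for_backward(x)
        ctx.autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                               "dtype": torch.get_autocast_gpu_dtype(),
                               "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        # re-enter autocast so the recomputed activations match the forward dtypes
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.autocast_kwargs):
            y = ctx.fn(x)
        grad_in, = torch.autograd.grad(y, x, grad_out)
        return None, grad_in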
@@ -148,7 +151,7 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + input_grads


-def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True):
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     Create sinusoidal timestep embeddings.
     :param timesteps: a 1-D Tensor of N indices, one per batch element.
@@ -168,10 +171,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16
         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
     else:
         embedding = repeat(timesteps, 'b -> b d', d=dim)
-    if use_fp16:
-        return embedding.half()
-    else:
-        return embedding
+    return embedding


 def zero_module(module):
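With use_fp16 gone, timestep_embedding always returns the fp32 sinusoidal table and leaves any dtype conversion to the surrounding autocast context. For reference, a standalone sketch of the standard sinusoidal embedding the function computes (the function name here is mine; the real one lives in ldm.modules.diffusionmodules.util):

import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    # frequencies geometrically spaced from 1 down to ~1/max_period
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad odd dims with a zero column, as the context above does
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb  # always fp32 now; callers rely on autocast instead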
@@ -199,16 +199,14 @@ def mean_flat(tensor):
     return tensor.mean(dim=list(range(1, len(tensor.shape))))


-def normalization(channels, precision=16):
+def normalization(channels):
     """
     Make a standard normalization layer.
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
-    if precision == 16:
-        return GroupNorm16(16, channels)
-    else:
-        return GroupNorm32(32, channels)
+    return nn.GroupNorm(16, channels)
+    # return GroupNorm32(32, channels)


 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
@@ -216,9 +214,6 @@ class SiLU(nn.Module):
     def forward(self, x):
         return x * torch.sigmoid(x)


-class GroupNorm16(nn.GroupNorm):
-    def forward(self, x):
-        return super().forward(x.half()).type(x.dtype)


 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):
...
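Since autocast now owns dtype handling, the fp16-forcing GroupNorm16 subclass is removed and normalization returns a plain nn.GroupNorm with 16 groups (the commented-out line preserves the upstream 32-group variant). One constraint worth noting: GroupNorm requires the channel count to be divisible by the group count, which the SD UNet widths satisfy. A trivial check:

import torch
import torch.nn as nn

norm = nn.GroupNorm(16, 320)     # 320 % 16 == 0; the SD UNet widths are multiples of 32
y = norm(torch.randn(2, 320, 32, 32))
print(y.shape)                   # torch.Size([2, 320, 32, 32])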
@@ -10,24 +10,28 @@ class LitEma(nn.Module):
         self.m_name2s_name = {}
         self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
-        self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
-                             else torch.tensor(-1,dtype=torch.int))
+        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+                             else torch.tensor(-1, dtype=torch.int))

         for name, p in model.named_parameters():
             if p.requires_grad:
-                #remove as '.'-character is not allowed in buffers
-                s_name = name.replace('.','')
-                self.m_name2s_name.update({name:s_name})
-                self.register_buffer(s_name,p.clone().detach().data)
+                # remove as '.'-character is not allowed in buffers
+                s_name = name.replace('.', '')
+                self.m_name2s_name.update({name: s_name})
+                self.register_buffer(s_name, p.clone().detach().data)

         self.collected_params = []

-    def forward(self,model):
+    def reset_num_updates(self):
+        del self.num_updates
+        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+    def forward(self, model):
         decay = self.decay

         if self.num_updates >= 0:
             self.num_updates += 1
-            decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
+            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

         one_minus_decay = 1.0 - decay
...
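Beyond the whitespace cleanup, the LitEma change adds reset_num_updates so a resumed run can restart the EMA warmup. The warmup comes from decay = min(self.decay, (1 + n) / (10 + n)): early updates use a small effective decay so the shadow weights track the live model quickly, then it ramps toward the configured maximum. A quick sketch of the schedule and the usual call pattern (loop names are placeholders):

# effective decay for decay=0.9999:
for n in (1, 10, 100, 100000):
    print(n, min(0.9999, (1 + n) / (10 + n)))
# 1 -> 0.1818..., 10 -> 0.55, 100 -> 0.9182..., 100000 -> 0.9999 (capped)

# typical usage (placeholder names):
# ema = LitEma(model)
# for batch in loader:
#     loss = train_step(model, batch)
#     optimizer.step(); optimizer.zero_grad()
#     ema(model)   # update shadow weights after each optimizer step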
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
import torch


class BaseModel(torch.nn.Module):

    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'))

        # full training checkpoints keep the weights under the "model" key
        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)
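BaseModel.load accepts both a bare state dict and a full training checkpoint: when the file also carries optimizer state, the weights are taken from the "model" key. A small usage sketch (the subclass and file names are made up):

import torch

class TinyModel(BaseModel):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

m = TinyModel()
torch.save(m.state_dict(), "weights.pt")                            # bare state dict
m.load("weights.pt")
torch.save({"model": m.state_dict(), "optimizer": {}}, "ckpt.pt")   # training checkpoint
m.load("ckpt.pt")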