Unverified Commit 6c4c6a04 authored by Fazzie-Maqianli's avatar Fazzie-Maqianli Committed by GitHub
Browse files

Merge pull request #2120 from Fazziekey/example/stablediffusion-v2

[example] support stable diffusion v2
parents 5efda697 cea4292a
import torch
import numpy as np
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
return x[(...,) + (None,) * dims_to_append]
def norm_thresholding(x0, value):
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)
\ No newline at end of file
......@@ -4,24 +4,17 @@ import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from ldm.modules.diffusionmodules.util import checkpoint
from torch.utils import checkpoint
try:
from ldm.modules.flash_attention import flash_attention_qkv, flash_attention_q_kv
FlASH_AVAILABLE = True
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
FlASH_AVAILABLE = False
USE_FLASH = False
def enable_flash_attention():
global USE_FLASH
USE_FLASH = True
if FlASH_AVAILABLE is False:
print("Please install flash attention to activate new attention kernel.\n" +
"Use \'pip install git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn\'")
XFORMERS_IS_AVAILBLE = False
def exists(val):
......@@ -93,25 +86,6 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
......@@ -184,85 +158,111 @@ class CrossAttention(nn.Module):
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
dim_head = q.shape[-1] / self.heads
if USE_FLASH and FlASH_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16) and \
dim_head <= 128 and (dim_head % 8) == 0:
# print("in flash")
if q.shape[1] == k.shape[1]:
out = self._flash_attention_qkv(q, k, v)
else:
out = self._flash_attention_q_kv(q, k, v)
else:
out = self._native_attention(q, k, v, self.heads, mask)
return self.to_out(out)
def _native_attention(self, q, k, v, h, mask):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
out = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', out, v)
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return out
def _flash_attention_qkv(self, q, k, v):
qkv = torch.stack([q, k, v], dim=2)
b = qkv.shape[0]
n = qkv.shape[1]
qkv = rearrange(qkv, 'b n t (h d) -> (b n) t h d', h=self.heads)
out = flash_attention_qkv(qkv, self.scale, b, n)
out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
return out
def _flash_attention_q_kv(self, q, k, v):
kv = torch.stack([k, v], dim=2)
b = q.shape[0]
q_seqlen = q.shape[1]
kv_seqlen = kv.shape[1]
q = rearrange(q, 'b n (h d) -> (b n) h d', h=self.heads)
kv = rearrange(kv, 'b n t (h d) -> (b n) t h d', h=self.heads)
out = flash_attention_q_kv(q, kv, self.scale, b, q_seqlen, kv_seqlen)
out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
return out
return self.to_out(out)
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads.")
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, use_checkpoint=False):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention
}
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.use_checkpoint = use_checkpoint
self.checkpoint = checkpoint
def forward(self, x, context=None):
if self.use_checkpoint:
return checkpoint(self._forward, x, context)
else:
return self._forward(x, context)
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
......@@ -272,43 +272,60 @@ class SpatialTransformer(nn.Module):
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None, use_checkpoint=False):
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_checkpoint=use_checkpoint)
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
if not use_linear:
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
x = x.contiguous()
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = x.contiguous()
x = self.proj_out(x)
return x + x_in
\ No newline at end of file
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
......@@ -4,9 +4,22 @@ import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from typing import Optional, Any
from ldm.util import instantiate_from_config
from ldm.modules.attention import LinearAttention
try:
from lightning.pytorch.utilities import rank_zero_info
except:
from pytorch_lightning.utilities import rank_zero_info
from ldm.modules.attention import MemoryEfficientCrossAttention
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
print("No module 'xformers'. Proceeding without it.")
def get_timestep_embedding(timesteps, embedding_dim):
......@@ -141,12 +154,6 @@ class ResnetBlock(nn.Module):
return x+h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
......@@ -174,7 +181,6 @@ class AttnBlock(nn.Module):
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
......@@ -201,21 +207,100 @@ class AttnBlock(nn.Module):
return x+h_
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.attention_op: Optional[Any] = None
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
out = self.proj_out(out)
return x+out
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None):
b, c, h, w = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
out = super().forward(x, context=context, mask=mask)
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
return x + out
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
attn_type = "vanilla-xformers"
rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
rank_zero_info(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
return LinAttnBlock(in_channels)
raise NotImplementedError()
class temb_module(nn.Module):
def __init__(self):
super().__init__()
pass
class Model(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
......@@ -233,8 +318,7 @@ class Model(nn.Module):
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
# self.temb = nn.Module()
self.temb = temb_module()
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch,
self.temb_ch),
......@@ -265,8 +349,7 @@ class Model(nn.Module):
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
# down = nn.Module()
down = Down_module()
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
......@@ -275,8 +358,7 @@ class Model(nn.Module):
self.down.append(down)
# middle
# self.mid = nn.Module()
self.mid = Mid_module()
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
......@@ -304,8 +386,7 @@ class Model(nn.Module):
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
# up = nn.Module()
up = Up_module()
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
......@@ -372,21 +453,6 @@ class Model(nn.Module):
def get_last_layer(self):
return self.conv_out.weight
class Down_module(nn.Module):
def __init__(self):
super().__init__()
pass
class Up_module(nn.Module):
def __init__(self):
super().__init__()
pass
class Mid_module(nn.Module):
def __init__(self):
super().__init__()
pass
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
......@@ -426,8 +492,7 @@ class Encoder(nn.Module):
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
# down = nn.Module()
down = Down_module()
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
......@@ -436,8 +501,7 @@ class Encoder(nn.Module):
self.down.append(down)
# middle
# self.mid = nn.Module()
self.mid = Mid_module()
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
......@@ -505,7 +569,7 @@ class Decoder(nn.Module):
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
rank_zero_info("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
......@@ -516,8 +580,7 @@ class Decoder(nn.Module):
padding=1)
# middle
# self.mid = nn.Module()
self.mid = Mid_module()
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
......@@ -542,8 +605,7 @@ class Decoder(nn.Module):
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
# up = nn.Module()
up = Up_module()
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
......@@ -758,7 +820,7 @@ class Upsampler(nn.Module):
assert out_size >= in_size
num_blocks = int(np.log2(out_size//in_size))+1
factor_up = 1.+ (out_size % in_size)
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
out_channels=in_channels)
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
......@@ -777,7 +839,7 @@ class Resize(nn.Module):
self.with_conv = learned
self.mode = mode
if self.with_conv:
print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
......@@ -793,70 +855,3 @@ class Resize(nn.Module):
else:
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
return x
class FirstStagePostProcessor(nn.Module):
def __init__(self, ch_mult:list, in_channels,
pretrained_model:nn.Module=None,
reshape=False,
n_channels=None,
dropout=0.,
pretrained_config=None):
super().__init__()
if pretrained_config is None:
assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
self.pretrained_model = pretrained_model
else:
assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
self.instantiate_pretrained(pretrained_config)
self.do_reshape = reshape
if n_channels is None:
n_channels = self.pretrained_model.encoder.ch
self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
stride=1,padding=1)
blocks = []
downs = []
ch_in = n_channels
for m in ch_mult:
blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
ch_in = m * n_channels
downs.append(Downsample(ch_in, with_conv=False))
self.model = nn.ModuleList(blocks)
self.downsampler = nn.ModuleList(downs)
def instantiate_pretrained(self, config):
model = instantiate_from_config(config)
self.pretrained_model = model.eval()
# self.pretrained_model.train = False
for param in self.pretrained_model.parameters():
param.requires_grad = False
@torch.no_grad()
def encode_with_pretrained(self,x):
c = self.pretrained_model.encode(x)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
return c
def forward(self,x):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
for submodel, downmodel in zip(self.model,self.downsampler):
z = submodel(z,temb=None)
z = downmodel(z)
if self.do_reshape:
z = rearrange(z,'b c h w -> b (h w) c')
return z
from abc import abstractmethod
from functools import partial
import math
from typing import Iterable
import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import checkpoint
from ldm.modules.diffusionmodules.util import (
checkpoint,
conv_nd,
linear,
avg_pool_nd,
......@@ -19,13 +16,11 @@ from ldm.modules.diffusionmodules.util import (
timestep_embedding,
)
from ldm.modules.attention import SpatialTransformer
from ldm.util import exists
# dummy replace
def convert_module_to_f16(x):
# for n,p in x.named_parameter():
# print(f"convert module {n} to_f16")
# p.data = p.data.half()
pass
def convert_module_to_f32(x):
......@@ -251,10 +246,9 @@ class ResBlock(TimestepBlock):
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if self.use_checkpoint:
return checkpoint(self._forward, x, emb)
else:
return self._forward(x, emb)
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
......@@ -317,11 +311,8 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
if self.use_checkpoint:
return checkpoint(self._forward, x) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
else:
return self._forward(x)
def _forward(self, x):
b, c, *spatial = x.shape
......@@ -474,7 +465,10 @@ class UNetModel(nn.Module):
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
from_pretrained: str=None
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
......@@ -499,7 +493,24 @@ class UNetModel(nn.Module):
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
......@@ -520,7 +531,13 @@ class UNetModel(nn.Module):
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
else:
raise ValueError()
self.input_blocks = nn.ModuleList(
[
......@@ -534,7 +551,7 @@ class UNetModel(nn.Module):
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
......@@ -556,17 +573,25 @@ class UNetModel(nn.Module):
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_checkpoint=use_checkpoint,
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
......@@ -618,8 +643,10 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
ResBlock(
ch,
......@@ -634,7 +661,7 @@ class UNetModel(nn.Module):
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
......@@ -657,18 +684,26 @@ class UNetModel(nn.Module):
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
)
)
if level and i == num_res_blocks:
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
......@@ -699,188 +734,6 @@ class UNetModel(nn.Module):
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
# if use_fp16:
# self.convert_to_fp16()
from diffusers.modeling_utils import load_state_dict
if from_pretrained is not None:
state_dict = load_state_dict(from_pretrained)
self._load_pretrained_model(state_dict)
def _input_blocks_mapping(self, input_dict):
res_dict = {}
for key_, value_ in input_dict.items():
id_0 = int(key_[13])
if "resnets" in key_:
id_1 = int(key_[23])
target_id = 3 * id_0 + 1 + id_1
post_fix = key_[25:].replace('time_emb_proj', 'emb_layers.1')\
.replace('norm1', 'in_layers.0')\
.replace('norm2', 'out_layers.0')\
.replace('conv1', 'in_layers.2')\
.replace('conv2', 'out_layers.3')\
.replace('conv_shortcut', 'skip_connection')
res_dict["input_blocks." + str(target_id) + '.0.' + post_fix] = value_
elif "attentions" in key_:
id_1 = int(key_[26])
target_id = 3 * id_0 + 1 + id_1
post_fix = key_[28:]
res_dict["input_blocks." + str(target_id) + '.1.' + post_fix] = value_
elif "downsamplers" in key_:
post_fix = key_[35:]
target_id = 3 * (id_0 + 1)
res_dict["input_blocks." + str(target_id) + '.0.op.' + post_fix] = value_
return res_dict
def _mid_blocks_mapping(self, mid_dict):
res_dict = {}
for key_, value_ in mid_dict.items():
if "resnets" in key_:
temp_key_ =key_.replace('time_emb_proj', 'emb_layers.1') \
.replace('norm1', 'in_layers.0') \
.replace('norm2', 'out_layers.0') \
.replace('conv1', 'in_layers.2') \
.replace('conv2', 'out_layers.3') \
.replace('conv_shortcut', 'skip_connection')\
.replace('middle_block.resnets.0', 'middle_block.0')\
.replace('middle_block.resnets.1', 'middle_block.2')
res_dict[temp_key_] = value_
elif "attentions" in key_:
res_dict[key_.replace('attentions.0', '1')] = value_
return res_dict
def _other_blocks_mapping(self, other_dict):
res_dict = {}
for key_, value_ in other_dict.items():
tmp_key = key_.replace('conv_in', 'input_blocks.0.0')\
.replace('time_embedding.linear_1', 'time_embed.0')\
.replace('time_embedding.linear_2', 'time_embed.2')\
.replace('conv_norm_out', 'out.0')\
.replace('conv_out', 'out.2')
res_dict[tmp_key] = value_
return res_dict
def _output_blocks_mapping(self, output_dict):
res_dict = {}
for key_, value_ in output_dict.items():
id_0 = int(key_[14])
if "resnets" in key_:
id_1 = int(key_[24])
target_id = 3 * id_0 + id_1
post_fix = key_[26:].replace('time_emb_proj', 'emb_layers.1') \
.replace('norm1', 'in_layers.0') \
.replace('norm2', 'out_layers.0') \
.replace('conv1', 'in_layers.2') \
.replace('conv2', 'out_layers.3') \
.replace('conv_shortcut', 'skip_connection')
res_dict["output_blocks." + str(target_id) + '.0.' + post_fix] = value_
elif "attentions" in key_:
id_1 = int(key_[27])
target_id = 3 * id_0 + id_1
post_fix = key_[29:]
res_dict["output_blocks." + str(target_id) + '.1.' + post_fix] = value_
elif "upsamplers" in key_:
post_fix = key_[34:]
target_id = 3 * (id_0 + 1) - 1
mid_str = '.2.conv.' if target_id != 2 else '.1.conv.'
res_dict["output_blocks." + str(target_id) + mid_str + post_fix] = value_
return res_dict
def _state_key_mapping(self, state_dict: dict):
import re
res_dict = {}
input_dict = {}
mid_dict = {}
output_dict = {}
other_dict = {}
for key_, value_ in state_dict.items():
if "down_blocks" in key_:
input_dict[key_.replace('down_blocks', 'input_blocks')] = value_
elif "up_blocks" in key_:
output_dict[key_.replace('up_blocks', 'output_blocks')] = value_
elif "mid_block" in key_:
mid_dict[key_.replace('mid_block', 'middle_block')] = value_
else:
other_dict[key_] = value_
input_dict = self._input_blocks_mapping(input_dict)
output_dict = self._output_blocks_mapping(output_dict)
mid_dict = self._mid_blocks_mapping(mid_dict)
other_dict = self._other_blocks_mapping(other_dict)
# key_list = state_dict.keys()
# key_str = " ".join(key_list)
# for key_, val_ in state_dict.items():
# key_ = key_.replace("down_blocks", "input_blocks")\
# .replace("up_blocks", 'output_blocks')
# res_dict[key_] = val_
res_dict.update(input_dict)
res_dict.update(output_dict)
res_dict.update(mid_dict)
res_dict.update(other_dict)
return res_dict
def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
state_dict = self._state_key_mapping(state_dict)
model_state_dict = self.state_dict()
loaded_keys = [k for k in state_dict.keys()]
expected_keys = list(model_state_dict.keys())
original_loaded_keys = loaded_keys
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
original_loaded_keys,
ignore_mismatched_sizes,
)
error_msgs = self._load_state_dict_into_model(state_dict)
return missing_keys, unexpected_keys, mismatched_keys, error_msgs
def _load_state_dict_into_model(self, state_dict):
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()
error_msgs = []
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: torch.nn.Module, prefix=""):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(self)
return error_msgs
def convert_to_fp16(self):
"""
......@@ -912,10 +765,11 @@ class UNetModel(nn.Module):
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
t_emb = t_emb.type(self.dtype)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
......@@ -926,227 +780,8 @@ class UNetModel(nn.Module):
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(self.dtype)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
class EncoderUNetModel(nn.Module):
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool="adaptive",
*args,
**kwargs
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.pool = pool
if pool == "adaptive":
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
nn.AdaptiveAvgPool2d((1, 1)),
zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(),
)
elif pool == "attention":
assert num_head_channels != -1
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
AttentionPool2d(
(image_size // ds), ch, num_head_channels, out_channels
),
)
elif pool == "spatial":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
)
elif pool == "spatial_v2":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
normalization(2048),
nn.SiLU(),
nn.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
def forward(self, x, timesteps):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
if self.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb)
if self.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
return self.out(h)
else:
h = h.type(self.dtype)
return self.out(h)
import torch
import torch.nn as nn
import numpy as np
from functools import partial
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ldm.util import default
class AbstractLowScaleModel(nn.Module):
# for concatenating a downsampled image to the latent representation
def __init__(self, noise_schedule_config=None):
super(AbstractLowScaleModel, self).__init__()
if noise_schedule_config is not None:
self.register_schedule(**noise_schedule_config)
def register_schedule(self, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def forward(self, x):
return x, None
def decode(self, x):
return x
class SimpleImageConcat(AbstractLowScaleModel):
# no noise level conditioning
def __init__(self):
super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
self.max_noise_level = 0
def forward(self, x):
# fix to constant noise level
return x, torch.zeros(x.shape[0], device=x.device).long()
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level
def forward(self, x, noise_level=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level)
return z, noise_level
......@@ -122,7 +122,9 @@ class CheckpointFunction(torch.autograd.Function):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled()}
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
......@@ -130,7 +132,8 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
with torch.enable_grad(), \
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
......@@ -148,7 +151,7 @@ class CheckpointFunction(torch.autograd.Function):
return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True):
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
......@@ -168,10 +171,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
if use_fp16:
return embedding.half()
else:
return embedding
return embedding
def zero_module(module):
......@@ -199,16 +199,14 @@ def mean_flat(tensor):
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels, precision=16):
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
if precision == 16:
return GroupNorm16(16, channels)
else:
return GroupNorm32(32, channels)
return nn.GroupNorm(16, channels)
# return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
......@@ -216,9 +214,6 @@ class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm16(nn.GroupNorm):
def forward(self, x):
return super().forward(x.half()).type(x.dtype)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
......
......@@ -10,24 +10,28 @@ class LitEma(nn.Module):
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
else torch.tensor(-1,dtype=torch.int))
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
else torch.tensor(-1, dtype=torch.int))
for name, p in model.named_parameters():
if p.requires_grad:
#remove as '.'-character is not allowed in buffers
s_name = name.replace('.','')
self.m_name2s_name.update({name:s_name})
self.register_buffer(s_name,p.clone().detach().data)
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self,model):
def reset_num_updates(self):
del self.num_updates
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
......
import types
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
import kornia
from transformers.models.clip.modeling_clip import CLIPTextTransformer
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
import open_clip
from ldm.util import default, count_params
class AbstractEncoder(nn.Module):
......@@ -20,189 +16,149 @@ class AbstractEncoder(nn.Module):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
self.n_classes = n_classes
self.ucg_rate = ucg_rate
def forward(self, batch, key=None):
def forward(self, batch, key=None, disable_dropout=False):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
if self.ucg_rate > 0. and not disable_dropout:
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
c = c.long()
c = self.embedding(c)
return c
def get_unconditional_conditioning(self, bs, device="cuda"):
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc = torch.ones((bs,), device=device) * uc_class
uc = {self.key: uc}
return uc
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, x):
return self(x)
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
self.max_length = max_length # TODO: typical value?
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
#self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
return tokens
@torch.no_grad()
def encode(self, text):
tokens = self(text)
if not self.vq_interface:
return tokens
return None, None, [None, None, tokens]
def decode(self, text):
return text
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
outputs = self.transformer(input_ids=tokens)
def forward(self, text):
if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True)
z = outputs.last_hidden_state
return z
def encode(self, text):
# output of length 77
return self(text)
class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
def forward(self,x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
def encode(self, x):
return self(x)
class CLIPTextModelZero(CLIPTextModel):
config_class = CLIPTextConfig
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
self.text_model = CLIPTextTransformerZero(config)
class CLIPTextTransformerZero(CLIPTextTransformer):
def _build_causal_attention_mask(self, bsz, seq_len):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask.half()
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_fp16=True):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
"pooled",
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
if use_fp16:
self.transformer = CLIPTextModelZero.from_pretrained(version)
else:
self.transformer = CLIPTextModel.from_pretrained(version)
# print(self.transformer.modules())
# print("check model dtyoe: {}, {}".format(self.tokenizer.dtype, self.transformer.dtype))
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
self.freeze()
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
if layer == "hidden":
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12
def freeze(self):
self.transformer = self.transformer.eval()
#self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
# tokens = batch_encoding["input_ids"].to(self.device)
tokens = batch_encoding["input_ids"].to(self.device)
# print("token type: {}".format(tokens.dtype))
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
return z
def encode(self, text):
return self(text)
class FrozenCLIPTextEmbedder(nn.Module):
class FrozenOpenCLIPEmbedder(AbstractEncoder):
"""
Uses the CLIP transformer encoder for text.
Uses the OpenCLIP transformer encoder for text
"""
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
LAYERS = [
#"pooled",
"last",
"penultimate"
]
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
freeze=True, layer="last"):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
self.n_repeat = n_repeat
self.normalize = normalize
if freeze:
self.freeze()
self.layer = layer
if self.layer == "last":
self.layer_idx = 0
elif self.layer == "penultimate":
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
......@@ -210,55 +166,48 @@ class FrozenCLIPTextEmbedder(nn.Module):
param.requires_grad = False
def forward(self, text):
tokens = clip.tokenize(text).to(self.device)
z = self.model.encode_text(tokens)
if self.normalize:
z = z / torch.linalg.norm(z, dim=1, keepdim=True)
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode(self, text):
z = self(text)
if z.ndim==2:
z = z[:, None, :]
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
class FrozenClipImageEmbedder(nn.Module):
"""
Uses the CLIP image encoder.
"""
def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
def encode(self, text):
return self(text)
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
clip_max_length=77, t5_max_length=77):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def encode(self, text):
return self(text)
def forward(self, x):
# x is assumed to be in range [-1,1]
return self.model.encode_image(self.preprocess(x))
def forward(self, text):
clip_z = self.clip_encoder.encode(text)
t5_z = self.t5_encoder.encode(text)
return [clip_z, t5_z]
if __name__ == "__main__":
from ldm.util import count_params
model = FrozenCLIPEmbedder()
count_params(model, verbose=True)
\ No newline at end of file
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
"""
import torch
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func
except ImportError:
raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len):
"""
Arguments:
qkv: (batch*seq, 3, nheads, headdim)
batch_size: int.
seq_len: int.
sm_scale: float. The scaling of QK^T before applying softmax.
Return:
out: (total, nheads, headdim).
"""
max_s = seq_len
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
device=qkv.device)
out = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, 0.0,
softmax_scale=sm_scale, causal=False
)
return out
def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen):
"""
Arguments:
q: (batch*seq, nheads, headdim)
kv: (batch*seq, 2, nheads, headdim)
batch_size: int.
seq_len: int.
sm_scale: float. The scaling of QK^T before applying softmax.
Return:
out: (total, nheads, headdim).
"""
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale)
return out
......@@ -25,7 +25,6 @@ import ldm.modules.image_degradation.utils_image as util
# --------------------------------------------
"""
def modcrop_np(img, sf):
'''
Args:
......@@ -254,7 +253,7 @@ def srmd_degradation(x, k, sf=3):
year={2018}
}
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
......@@ -277,7 +276,7 @@ def dpsr_degradation(x, k, sf=3):
}
'''
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x
......@@ -290,7 +289,7 @@ def classical_degradation(x, k, sf=3):
Return:
downsampled LR image
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
return x[st::sf, st::sf, ...]
......@@ -335,7 +334,7 @@ def add_blur(img, sf=4):
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
else:
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
return img
......@@ -497,7 +496,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
......@@ -531,7 +530,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
......@@ -589,7 +588,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
......@@ -617,6 +616,8 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
if up:
image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
example = {"image": image}
return example
......
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
import torch
import torch.nn as nn
from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
class LPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_loss="hinge"):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
global_step, last_layer=None, cond=None, split="train",
weights=None):
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights*nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
return d_loss, log
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.):
if global_step < threshold:
weight = value
return weight
def measure_perplexity(predicted_indices, n_embed):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
def l1(x, y):
return torch.abs(x-y)
def l2(x, y):
return torch.pow((x-y), 2)
class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
pixel_loss="l1"):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert perceptual_loss in ["lpips", "clips", "dists"]
assert pixel_loss in ["l1", "l2"]
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips":
print(f"{self.__class__.__name__}: Running with LPIPS.")
self.perceptual_loss = LPIPS().eval()
else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
self.perceptual_weight = perceptual_weight
if pixel_loss == "l1":
self.pixel_loss = l1
else:
self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf
).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
self.n_classes = n_classes
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
if not exists(codebook_loss):
codebook_loss = torch.tensor([0.]).to(inputs.device)
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
if predicted_indices is not None:
assert self.n_classes is not None
with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
log[f"{split}/perplexity"] = perplexity
log[f"{split}/cluster_usage"] = cluster_usage
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
return d_loss, log
# based on https://github.com/isl-org/MiDaS
import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
from ldm.modules.midas.midas.midas_net import MidasNet
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
ISL_PATHS = {
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
"midas_v21": "",
"midas_v21_small": "",
}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def load_midas_transform(model_type):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load transform only
if model_type == "dpt_large": # DPT-Large
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_hybrid": # DPT-Hybrid
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "midas_v21":
net_w, net_h = 384, 384
resize_mode = "upper_bound"
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
elif model_type == "midas_v21_small":
net_w, net_h = 256, 256
resize_mode = "upper_bound"
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
else:
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
transform = Compose(
[
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
normalization,
PrepareForNet(),
]
)
return transform
def load_model(model_type):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load network
model_path = ISL_PATHS[model_type]
if model_type == "dpt_large": # DPT-Large
model = DPTDepthModel(
path=model_path,
backbone="vitl16_384",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_hybrid": # DPT-Hybrid
model = DPTDepthModel(
path=model_path,
backbone="vitb_rn50_384",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "midas_v21":
model = MidasNet(model_path, non_negative=True)
net_w, net_h = 384, 384
resize_mode = "upper_bound"
normalization = NormalizeImage(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
elif model_type == "midas_v21_small":
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
non_negative=True, blocks={'expand': True})
net_w, net_h = 256, 256
resize_mode = "upper_bound"
normalization = NormalizeImage(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
else:
print(f"model_type '{model_type}' not implemented, use: --model_type large")
assert False
transform = Compose(
[
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
normalization,
PrepareForNet(),
]
)
return model.eval(), transform
class MiDaSInference(nn.Module):
MODEL_TYPES_TORCH_HUB = [
"DPT_Large",
"DPT_Hybrid",
"MiDaS_small"
]
MODEL_TYPES_ISL = [
"dpt_large",
"dpt_hybrid",
"midas_v21",
"midas_v21_small",
]
def __init__(self, model_type):
super().__init__()
assert (model_type in self.MODEL_TYPES_ISL)
model, _ = load_model(model_type)
self.model = model
self.model.train = disabled_train
def forward(self, x):
# x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
# NOTE: we expect that the correct transform has been called during dataloading.
with torch.no_grad():
prediction = self.model(x)
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=x.shape[2:],
mode="bicubic",
align_corners=False,
)
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
return prediction
import torch
class BaseModel(torch.nn.Module):
def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = torch.load(path, map_location=torch.device('cpu'))
if "optimizer" in parameters:
parameters = parameters["model"]
self.load_state_dict(parameters)
import torch
import torch.nn as nn
from .vit import (
_make_pretrained_vitb_rn50_384,
_make_pretrained_vitl16_384,
_make_pretrained_vitb16_384,
forward_vit,
)
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
if backbone == "vitl16_384":
pretrained = _make_pretrained_vitl16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # ViT-L/16 - 85.0% Top1 (backbone)
elif backbone == "vitb_rn50_384":
pretrained = _make_pretrained_vitb_rn50_384(
use_pretrained,
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
scratch = _make_scratch(
[256, 512, 768, 768], features, groups=groups, expand=expand
) # ViT-H/16 - 85.0% Top1 (backbone)
elif backbone == "vitb16_384":
pretrained = _make_pretrained_vitb16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
) # ViT-B/16 - 84.6% Top1 (backbone)
elif backbone == "resnext101_wsl":
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
elif backbone == "efficientnet_lite3":
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
else:
print(f"Backbone '{backbone}' not implemented")
assert False
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand==True:
out_shape1 = out_shape
out_shape2 = out_shape*2
out_shape3 = out_shape*4
out_shape4 = out_shape*8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
efficientnet = torch.hub.load(
"rwightman/gen-efficientnet-pytorch",
"tf_efficientnet_lite3",
pretrained=use_pretrained,
exportable=exportable
)
return _make_efficientnet_backbone(efficientnet)
def _make_efficientnet_backbone(effnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
)
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
return pretrained
def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
return _make_resnet_backbone(resnet)
class Interpolate(nn.Module):
"""Interpolation module.
"""
def __init__(self, scale_factor, mode, align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
)
return x
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=True
)
return output
class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups=1
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
if self.bn==True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn==True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn==True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
# return out + x
class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups=1
self.expand = expand
out_features = features
if self.expand==True:
out_features = features//2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
# output += res
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_model import BaseModel
from .blocks import (
FeatureFusionBlock,
FeatureFusionBlock_custom,
Interpolate,
_make_encoder,
forward_vit,
)
def _make_fusion_block(features, use_bn):
return FeatureFusionBlock_custom(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
)
class DPT(BaseModel):
def __init__(
self,
head,
features=256,
backbone="vitb_rn50_384",
readout="project",
channels_last=False,
use_bn=False,
):
super(DPT, self).__init__()
self.channels_last = channels_last
hooks = {
"vitb_rn50_384": [0, 1, 8, 11],
"vitb16_384": [2, 5, 8, 11],
"vitl16_384": [5, 11, 17, 23],
}
# Instantiate backbone and reassemble blocks
self.pretrained, self.scratch = _make_encoder(
backbone,
features,
False, # Set to true of you want to train from scratch, uses ImageNet weights
groups=1,
expand=False,
exportable=False,
hooks=hooks[backbone],
use_readout=readout,
)
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
self.scratch.output_conv = head
def forward(self, x):
if self.channels_last == True:
x.contiguous(memory_format=torch.channels_last)
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return out
class DPTDepthModel(DPT):
def __init__(self, path=None, non_negative=True, **kwargs):
features = kwargs["features"] if "features" in kwargs else 256
head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
super().__init__(head, **kwargs)
if path is not None:
self.load(path)
def forward(self, x):
return super().forward(x).squeeze(dim=1)
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
class MidasNet(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self, path=None, features=256, non_negative=True):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print("Loading weights: ", path)
super(MidasNet, self).__init__()
use_pretrained = False if path is None else True
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
self.scratch.refinenet4 = FeatureFusionBlock(features)
self.scratch.refinenet3 = FeatureFusionBlock(features)
self.scratch.refinenet2 = FeatureFusionBlock(features)
self.scratch.refinenet1 = FeatureFusionBlock(features)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear"),
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment