Unverified Commit 6c4c6a04 authored by Fazzie-Maqianli, committed by GitHub

Merge pull request #2120 from Fazziekey/example/stablediffusion-v2

[example] support stable diffusion v2
parents 5efda697 cea4292a
import torch
import numpy as np


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    return x[(...,) + (None,) * dims_to_append]


def norm_thresholding(x0, value):
    # rescale each sample so its RMS over all elements is clamped from below by `value`
    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
    return x0 * (value / s)


def spatial_norm_thresholding(x0, value):
    # b c h w: per-pixel RMS across the channel dimension
    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
    return x0 * (value / s)
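For reference, a minimal usage sketch of these helpers (the shapes and values are illustrative assumptions, not part of the commit):

import torch

x0 = torch.randn(2, 4, 64, 64)                          # b c h w latent batch
out_global = norm_thresholding(x0, value=1.0)           # rescale by per-sample RMS, clamped at 1.0
out_spatial = spatial_norm_thresholding(x0, value=1.0)  # rescale by per-pixel RMS across channels

# append_dims broadcasts per-sample scalars back to the tensor's rank:
sigma = torch.tensor([0.5, 0.7])
sigma = append_dims(sigma, x0.ndim)                     # shape (2, 1, 1, 1)
print((sigma * x0).shape)                               # torch.Size([2, 4, 64, 64])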
@@ -4,24 +4,17 @@ import torch
 import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat
+from typing import Optional, Any

-from torch.utils import checkpoint
+from ldm.modules.diffusionmodules.util import checkpoint

 try:
-    from ldm.modules.flash_attention import flash_attention_qkv, flash_attention_q_kv
-    FlASH_AVAILABLE = True
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
 except:
-    FlASH_AVAILABLE = False
-
-USE_FLASH = False
-
-def enable_flash_attention():
-    global USE_FLASH
-    USE_FLASH = True
-    if FlASH_AVAILABLE is False:
-        print("Please install flash attention to activate new attention kernel.\n" +
-              "Use \'pip install git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn\'")
+    XFORMERS_IS_AVAILBLE = False

 def exists(val):
@@ -93,25 +86,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
-        k = k.softmax(dim=-1)
-        context = torch.einsum('bhdn,bhen->bhde', k, v)
-        out = torch.einsum('bhde,bhdn->bhen', context, q)
-        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
-        return self.to_out(out)

 class SpatialSelfAttention(nn.Module):
     def __init__(self, in_channels):
         super().__init__()
@@ -184,87 +158,113 @@ class CrossAttention(nn.Module):
         )

     def forward(self, x, context=None, mask=None):
+        h = self.heads
+
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
         v = self.to_v(context)

-        dim_head = q.shape[-1] / self.heads
-        if USE_FLASH and FlASH_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16) and \
-                dim_head <= 128 and (dim_head % 8) == 0:
-            # print("in flash")
-            if q.shape[1] == k.shape[1]:
-                out = self._flash_attention_qkv(q, k, v)
-            else:
-                out = self._flash_attention_q_kv(q, k, v)
-        else:
-            out = self._native_attention(q, k, v, self.heads, mask)
-        return self.to_out(out)
-
-    def _native_attention(self, q, k, v, h, mask):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        del q, k

         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = repeat(mask, 'b j -> (b h) () j', h=h)
             sim.masked_fill_(~mask, max_neg_value)

         # attention, what we cannot get enough of
-        out = sim.softmax(dim=-1)
-        out = einsum('b i j, b j d -> b i d', out, v)
+        sim = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        return out
-
-    def _flash_attention_qkv(self, q, k, v):
-        qkv = torch.stack([q, k, v], dim=2)
-        b = qkv.shape[0]
-        n = qkv.shape[1]
-        qkv = rearrange(qkv, 'b n t (h d) -> (b n) t h d', h=self.heads)
-        out = flash_attention_qkv(qkv, self.scale, b, n)
-        out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
-        return out
-
-    def _flash_attention_q_kv(self, q, k, v):
-        kv = torch.stack([k, v], dim=2)
-        b = q.shape[0]
-        q_seqlen = q.shape[1]
-        kv_seqlen = kv.shape[1]
-        q = rearrange(q, 'b n (h d) -> (b n) h d', h=self.heads)
-        kv = rearrange(kv, 'b n t (h d) -> (b n) t h d', h=self.heads)
-        out = flash_attention_q_kv(q, kv, self.scale, b, q_seqlen, kv_seqlen)
-        out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
-        return out
+        return self.to_out(out)

+class MemoryEfficientCrossAttention(nn.Module):
+    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+              f"{heads} heads.")
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+        return self.to_out(out)
 class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, use_checkpoint=False):
+    ATTENTION_MODES = {
+        "softmax": CrossAttention,  # vanilla attention
+        "softmax-xformers": MemoryEfficientCrossAttention
+    }
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+                 disable_self_attn=False):
         super().__init__()
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
+        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        assert attn_mode in self.ATTENTION_MODES
+        attn_cls = self.ATTENTION_MODES[attn_mode]
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
+        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
         self.norm3 = nn.LayerNorm(dim)
-        self.use_checkpoint = use_checkpoint
+        self.checkpoint = checkpoint

     def forward(self, x, context=None):
-        if self.use_checkpoint:
-            return checkpoint(self._forward, x, context)
-        else:
-            return self._forward(x, context)
+        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

     def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x
         return x

 class SpatialTransformer(nn.Module):
     """
     Transformer block for image-like data.
@@ -272,43 +272,60 @@ class SpatialTransformer(nn.Module):
     and reshape to b, t, d.
     Then apply standard transformer action.
     Finally, reshape to image
+    NEW: use_linear for more efficiency instead of the 1x1 convs
     """
     def __init__(self, in_channels, n_heads, d_head,
-                 depth=1, dropout=0., context_dim=None, use_checkpoint=False):
+                 depth=1, dropout=0., context_dim=None,
+                 disable_self_attn=False, use_linear=False,
+                 use_checkpoint=True):
         super().__init__()
+        if exists(context_dim) and not isinstance(context_dim, list):
+            context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = Normalize(in_channels)
-        self.proj_in = nn.Conv2d(in_channels,
-                                 inner_dim,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
+        if not use_linear:
+            self.proj_in = nn.Conv2d(in_channels,
+                                     inner_dim,
+                                     kernel_size=1,
+                                     stride=1,
+                                     padding=0)
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)

         self.transformer_blocks = nn.ModuleList(
-            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_checkpoint=use_checkpoint)
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
                 for d in range(depth)]
         )
-        self.proj_out = zero_module(nn.Conv2d(inner_dim,
-                                              in_channels,
-                                              kernel_size=1,
-                                              stride=1,
-                                              padding=0))
+        if not use_linear:
+            self.proj_out = zero_module(nn.Conv2d(inner_dim,
+                                                  in_channels,
+                                                  kernel_size=1,
+                                                  stride=1,
+                                                  padding=0))
+        else:
+            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+        self.use_linear = use_linear

     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
+        if not isinstance(context, list):
+            context = [context]
         b, c, h, w = x.shape
         x_in = x
         x = self.norm(x)
-        x = self.proj_in(x)
-        x = rearrange(x, 'b c h w -> b (h w) c')
-        x = x.contiguous()
-        for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
-        x = x.contiguous()
-        x = self.proj_out(x)
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+        for i, block in enumerate(self.transformer_blocks):
+            x = block(x, context=context[i])
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = self.proj_out(x)
         return x + x_in
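For context, a minimal smoke-test sketch of the merged attention modules. The class names come from the diff above; the shapes, dimensions, and the assumption that the repo's `ldm` package is importable are mine. When xformers is installed, the memory-efficient path is selected automatically inside BasicTransformerBlock; otherwise the vanilla softmax CrossAttention is used.

import torch
from ldm.modules.attention import CrossAttention, SpatialTransformer

# Vanilla softmax attention path (used when xformers is absent).
attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 4096, 320)        # b, h*w tokens, channels
ctx = torch.randn(2, 77, 768)        # e.g. CLIP text embeddings
print(attn(x, context=ctx).shape)    # torch.Size([2, 4096, 320])

# SpatialTransformer now accepts use_linear (Linear projections instead of
# 1x1 convs) and wraps a scalar context_dim into a per-depth list.
st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1,
                        context_dim=768, use_linear=True, use_checkpoint=False)
feat = torch.randn(2, 320, 64, 64)   # b c h w feature map
print(st(feat, context=ctx).shape)   # torch.Size([2, 320, 64, 64])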
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator