"benchmark/vscode:/vscode.git/clone" did not exist on "49afb3d9d9deedf6dea3a6dd5c50e85e7d8bcb07"
modeling_wuerstchen_common.py 3.05 KB
Newer Older
Kashif Rasul's avatar
Kashif Rasul committed
1
2
3
4
import torch
import torch.nn as nn

from ...models.attention_processor import Attention
from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
from ...utils import USE_PEFT_BACKEND


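# LayerNorm over the channel dimension of NCHW feature maps: the input is
# permuted to channels-last, normalized, and permuted back.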
class WuerstchenLayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = super().forward(x)
        return x.permute(0, 3, 1, 2)


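# Timestep conditioning via feature-wise affine modulation: the timestep
# embedding is projected to a per-channel scale and shift.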
class TimestepBlock(nn.Module):
    def __init__(self, c, c_timestep):
        super().__init__()
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
        self.mapper = linear_cls(c_timestep, c * 2)

    def forward(self, x, t):
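        # project the timestep embedding to (batch, 2c, 1, 1), then split into scale a / shift b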
        a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
        return x * (1 + a) + b


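# ConvNeXt-style residual block: a depthwise conv plus LayerNorm, followed by a
# channelwise MLP (linear -> GELU -> GRN -> dropout -> linear) applied in
# channels-last layout. An optional skip tensor is concatenated on the channel
# axis before the depthwise conv.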
class ResBlock(nn.Module):
    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
        super().__init__()

        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

        self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
        )

    def forward(self, x, x_skip=None):
        x_res = x
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
        x = self.channelwise(x).permute(0, 3, 1, 2)
        return x + x_res


# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
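# Global Response Norm (ConvNeXt V2): operates on channels-last input, scaling
# each channel by its spatial L2 norm relative to the mean norm across channels.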
class GlobalResponseNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * stand_div_norm) + self.beta + x


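# Attention block over flattened image tokens. With `self_attn=True`, the
# normalized image tokens are concatenated with the conditioning sequence so a
# single attention call covers both self- and cross-attention.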
class AttnBlock(nn.Module):
    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
        super().__init__()

        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

        self.self_attn = self_attn
        self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))

    def forward(self, x, kv):
        kv = self.kv_mapper(kv)
        norm_x = self.norm(x)
        if self.self_attn:
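            # flatten NCHW image tokens to (batch, h*w, c) and prepend them to the conditioning kv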
            batch_size, channel, _, _ = x.shape
            kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1)
        x = x + self.attention(norm_x, encoder_hidden_states=kv)
        return x
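

# A minimal shape-check sketch (hypothetical, not part of the module; the
# relative imports above mean this file only runs inside the diffusers
# package). The channel counts below are illustrative:
#
#   x = torch.randn(1, 64, 16, 16)     # NCHW feature map
#   t = torch.randn(1, 64)             # timestep embedding
#   cond = torch.randn(1, 77, 1024)    # conditioning sequence
#   assert TimestepBlock(64, 64)(x, t).shape == x.shape
#   assert ResBlock(64)(x).shape == x.shape
#   assert AttnBlock(64, 1024, nhead=8)(x, cond).shape == x.shape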