sfr_lftg.py 2.4 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import xformers.ops

# spatial feature refiner
class SpatialFeatureRefiner(nn.Module):
    def __init__(self, hidden_channels):
        super(SpatialFeatureRefiner, self).__init__()

        # high-frequency branch
        self.hf_linear = nn.Linear(hidden_channels, hidden_channels * 2)

        # low-frequency branch
        self.lf_linear = nn.Linear(hidden_channels, hidden_channels * 2)

        # fusion
        self.gelu = nn.GELU()
        self.fusion_linear = nn.Linear(hidden_channels * 2, hidden_channels)

    def forward(self, hf_feature, lf_feature, x):
        
        # high-frequency branch
        hf_feature = self.hf_linear(hf_feature)
        scale_hf, shift_hf = hf_feature.chunk(2, dim=-1)
        x_hf = x * scale_hf + shift_hf

        # low-frequency branch
        lf_feature = self.lf_linear(lf_feature)
        scale_lf, shift_lf = lf_feature.chunk(2, dim=-1)
        x_lf = x * scale_lf + shift_lf

        # fusion
        x_fusion = torch.cat([x_hf, x_lf], dim=-1)
        x_fusion = self.gelu(x_fusion)
        x_fusion = self.fusion_linear(x_fusion)

        return x_fusion

    
# low-frequency temporal guider
class LFTemporalGuider(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(LFTemporalGuider, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape

        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        attn_bias = None
        if mask is not None:
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)

        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x