# fdie_arch.py
import torch
import torch.nn as nn
from opensora.models.vsr.safmn_arch import SAFMN
import torch.nn.functional as F
from einops import rearrange
from timm.models.vision_transformer import Mlp
from opensora.models.layers.blocks import (
    Attention,
    MultiHeadCrossAttention,
    PatchEmbed3D,
    get_1d_sincos_pos_embed,
    get_2d_sincos_pos_embed,
    get_layernorm,
)


# high-pass filter
def high_pass_filter(x, kernel_size=21):
    """
    Split the input tensor into high-frequency and low-frequency parts
    via Gaussian high-pass filtering.

    Args:
        x (torch.Tensor): input of shape [B, C, T, H, W], values in [-1, 1].
        kernel_size (int): size of the Gaussian kernel.

    Returns:
        high_freq (torch.Tensor): high-frequency part, same shape as x.
        low_freq (torch.Tensor): low-frequency part, same shape as x.
    """
    # derive sigma from the kernel size
    sigma = kernel_size / 6
    
    # match the device and dtype of the input tensor
    device, dtype = x.device, x.dtype
    
    # fold time into the batch: [B, C, T, H, W] -> [B*T, C, H, W]
    B, C, T, H, W = x.shape
    x_reshaped = x.contiguous().view(B * T, C, H, W)
    
    # build a normalized Gaussian kernel
    def get_gaussian_kernel(kernel_size, sigma):
        axis = torch.arange(kernel_size, dtype=dtype, device=device) - kernel_size // 2
        gaussian = torch.exp(-0.5 * (axis / sigma) ** 2)
        gaussian /= gaussian.sum()
        return gaussian
    
    gaussian_1d = get_gaussian_kernel(kernel_size, sigma)
    gaussian_2d = torch.outer(gaussian_1d, gaussian_1d)
    gaussian_3d = gaussian_2d.unsqueeze(0).unsqueeze(0)  # [1, 1, kernel_size, kernel_size]
    
    # replicate the kernel across channels for a depth-wise convolution
    gaussian_kernel = gaussian_3d.expand(C, 1, kernel_size, kernel_size)
    
    # convolve with F.conv2d; "same" padding preserves the spatial size
    padding = kernel_size // 2
    
    # low-frequency part: per-channel Gaussian blur (groups=C)
    low_freq_reshaped = F.conv2d(x_reshaped, gaussian_kernel, padding=padding, groups=C)
    
    # high-frequency part: residual of the blur
    high_freq_reshaped = x_reshaped - low_freq_reshaped
    
    # restore the original layout: [B*T, C, H, W] -> [B, C, T, H, W]
    low_freq = low_freq_reshaped.view(B, C, T, H, W)
    high_freq = high_freq_reshaped.view(B, C, T, H, W)
    
    return high_freq, low_freq
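
# Usage sketch (illustrative addition, not part of the original module):
# the blur and its residual are complementary bands, so they sum back to
# the input exactly up to floating-point error.
def _demo_high_pass_filter():
    x = torch.randn(2, 3, 16, 64, 64)  # [B, C, T, H, W]
    high_freq, low_freq = high_pass_filter(x, kernel_size=21)
    assert high_freq.shape == x.shape and low_freq.shape == x.shape
    assert torch.allclose(high_freq + low_freq, x, atol=1e-5)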


# depth-wise separable convolution
class DepthWiseSeparableResBlock(nn.Module):
    def __init__(self, in_channels, kernel_size=3, stride=1, padding=1, bias=False):
        super(DepthWiseSeparableResBlock, self).__init__()

        # note: groups=in_channels is currently commented out, so these act
        # as plain convolutions rather than true depth-wise ones
        self.dwconv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)  # groups=in_channels,
        # self.conv1 = nn.Conv2d(in_channels, in_channels, 1, bias=bias)

        self.gelu = nn.GELU()

        self.dwconv2 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)  # groups=in_channels,
        # self.conv2 = nn.Conv2d(in_channels, in_channels, 1, bias=bias)

    def forward(self, x):
        residual = x

        out = self.dwconv1(x)
        # out = self.conv1(out)
        out = self.gelu(out)

        out = self.dwconv2(out)
        # out = self.conv2(out)

        out += residual

        return out
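
# Usage sketch (illustrative addition): the residual block preserves tensor
# shape, so it can be stacked freely over per-frame features.
def _demo_res_block():
    block = DepthWiseSeparableResBlock(in_channels=64)
    feat = torch.randn(1, 64, 32, 32)  # [B*T, C, H, W]
    out = block(feat)
    assert out.shape == feat.shape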
    
# temporal transformer block
class TemporalTransformerBlock(nn.Module):
    def __init__(self):
        super(TemporalTransformerBlock, self).__init__()

        # temporal norm
        self.temporal_norm = get_layernorm(1152, eps=1e-6, affine=False, use_kernel=True)

        # temporal self-attention
        self.temporal_attn = Attention(
            dim=1152,
            num_heads=16,
            qkv_bias=True,
            enable_flashattn=True)

        # ffn
        self.temporal_ffn = Mlp(in_features=1152, hidden_features=4608, out_features=1152, act_layer=nn.GELU)

    def forward(self, x):
        # a single residual connection spans the whole norm -> attention -> FFN stack
        residual = x

        out = self.temporal_norm(x)
        out = self.temporal_attn(out)
        out = self.temporal_ffn(out)

        out += residual

        return out
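
# Usage sketch (illustrative addition; assumes the fused LayerNorm and
# flash-attention kernels requested above, use_kernel=True and
# enable_flashattn=True, are installed): tokens are laid out as
# [batch * spatial, temporal, channel], so self-attention mixes
# information across frames at each spatial position.
def _demo_temporal_block():
    block = TemporalTransformerBlock()
    tokens = torch.randn(4, 16, 1152)  # [B*S, T, C=1152]
    out = block(tokens)
    assert out.shape == tokens.shape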



# frequency-decoupled information extractor
class FrequencyDecoupledInfoExtractor(nn.Module):
    def __init__(self, in_channels, hidden_channels, kernel_size=3, stride=1, padding=1, bias=True):
        super(FrequencyDecoupledInfoExtractor, self).__init__()

        ### spatial branch ###
        # pretrained SAFMN x4 restorer; weights are loaded from the EMA params
        self.safmn = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=4, use_res=True)
        state_dict = torch.load('/mnt/bn/videodataset/VSR/pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth')
        self.safmn.load_state_dict(state_dict['params_ema'], strict=True)

        # high-frequency branch
        # self.hf_convin = nn.Conv2d(in_channels, hidden_channels, kernel_size, stride, padding, bias=bias)
        # self.hf_convout = nn.Conv2d(hidden_channels, in_channels, kernel_size, stride, padding, bias=bias)
        # hf_layer = []
        # for i in range(8):
        #     hf_layer.append(DepthWiseSeparableResBlock(hidden_channels, kernel_size, stride=1, padding=1, bias=bias))
        # self.hf_body = nn.Sequential(*hf_layer)
        self.safmn1 = SAFMN(dim=72, n_blocks=8, ffn_scale=2.0, upscaling_factor=1, in_dim=6, use_res=True)

        # low-frequency branch
        # self.lf_convin = nn.Conv2d(in_channels, hidden_channels, kernel_size, stride, padding, bias=bias)
        # self.lf_convout = nn.Conv2d(hidden_channels, in_channels, kernel_size, stride, padding, bias=bias)
        # lf_layer = []
        # for i in range(8):
        #     lf_layer.append(DepthWiseSeparableResBlock(hidden_channels, kernel_size, stride=1, padding=1, bias=bias))
        # self.lf_body = nn.Sequential(*lf_layer)
        self.safmn2 = SAFMN(dim=72, n_blocks=8, ffn_scale=2.0, upscaling_factor=1, in_dim=6, use_res=True)

        ### temporal branch ###
        layer = []
        for i in range(3):
            layer.append(TemporalTransformerBlock())
        self.temporal_body = nn.Sequential(*layer)


    def get_temporal_pos_embed(self):
        pos_embed = get_1d_sincos_pos_embed(
            embed_dim=1152,
            length=16,
            scale=1.0,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed
    

    def spatial_forward(self, x):
        # the pretrained SAFMN restorer runs frozen (no gradients); it cleans
        # a 1/4-scale version of the input before frequency decoupling
        with torch.no_grad():
            x = rearrange(x, 'B C T H W -> (B T) C H W')
            x = F.interpolate(x, scale_factor=1/4, mode='bilinear')
            clean_image = self.safmn(x)
            clean_image = rearrange(clean_image, '(B T) C H W -> B C T H W', T=16)
            high_freq, low_freq = high_pass_filter(clean_image)
            fea_decouple = torch.cat([high_freq, low_freq], dim=1)
            fea_decouple = rearrange(fea_decouple, 'B C T H W -> (B T) C H W')

        # high-frequency branch
        # hf_out = self.hf_convin(high_freq)
        # hf_out = self.hf_body(hf_out)
        # hf_out = self.hf_convout(hf_out) + high_freq
        hf_out = self.safmn1(fea_decouple)
        hf_out = rearrange(hf_out, '(B T) C H W -> B C T H W', T=16)

        # low-frequency branch
        # lf_out = self.lf_convin(low_freq)
        # lf_out = self.lf_body(lf_out)
        # lf_out = self.lf_convout(lf_out) + low_freq
        lf_out = self.safmn2(fea_decouple)
        lf_out = rearrange(lf_out, '(B T) C H W -> B C T H W', T=16)

        return clean_image, hf_out, lf_out
    
    def temporal_forward(self, x):
        # token layout: [B, T*S, C] -> [B*S, T, C]; T=16 frames and S=256
        # spatial positions are hard-coded to the training configuration
        x = rearrange(x, "B (T S) C -> (B S) T C", T=16)
        tpe = self.get_temporal_pos_embed().to(x.device, x.dtype)
        x = x + tpe
        x = self.temporal_body(x)
        x = rearrange(x, "(B S) T C -> B (T S) C", S=256)
        return x
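
# Usage sketch (illustrative addition; it assumes the SAFMN checkpoint path
# above exists and that the attention/LayerNorm kernels are available).
# spatial_forward takes pixel-space video, while temporal_forward takes
# 1152-dim latent tokens with the hard-coded T=16 and S=256.
def _demo_extractor():
    fdie = FrequencyDecoupledInfoExtractor(in_channels=3, hidden_channels=64)
    video = torch.randn(1, 3, 16, 256, 256)  # [B, C, T, H, W]
    clean_image, hf_out, lf_out = fdie.spatial_forward(video)
    tokens = torch.randn(1, 16 * 256, 1152)  # [B, T*S, C]
    tokens = fdie.temporal_forward(tokens)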