wavelet.py 1.78 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn


def Normalize(x, ymin=0, ymax=255):
    """Linearly rescale ``x`` so its minimum maps to ``ymin`` and its maximum to ``ymax``.

    Works on any array-like that supports ``.max()``, ``.min()`` and
    element-wise arithmetic (e.g. a torch tensor). The defaults reproduce the
    original fixed 0-255 target range.

    Args:
        x: values to rescale.
        ymin: value assigned to the smallest element of ``x`` (default 0).
        ymax: value assigned to the largest element of ``x`` (default 255).

    Returns:
        The rescaled values, same shape as ``x``. If every element of ``x`` is
        equal, the linear map is undefined (0/0), so every element is mapped
        to ``ymin`` instead of producing NaN.
    """
    xmax = x.max()
    xmin = x.min()
    span = xmax - xmin
    if span == 0:
        # Constant input: avoid 0/0 -> NaN. Multiplying by 0.0 keeps x's
        # shape/device and promotes integer input to floating point, matching
        # the dtype the division in the general branch would produce.
        return (x - xmin) * 0.0 + ymin
    return (ymax - ymin) * (x - xmin) / span + ymin


def dwt_init(x):
    """Apply one level of the 2-D Haar discrete wavelet transform.

    Axes 3 and 4 of ``x`` are treated as rows and columns and are each
    downsampled by 2. The four sub-bands are concatenated along axis 1 in the
    order LL, HL, LH, HH, which is the layout ``iwt_init`` splits apart.

    Args:
        x: tensor with at least 5 axes, indexed as ``x[:, :, :, row, col]``;
            the row and column axes must have even length.

    Returns:
        Tensor whose axis 1 is 4x longer and whose axes 3 and 4 are halved.

    Raises:
        ValueError: if axis 3 or axis 4 has odd length.
    """
    height, width = x.size(3), x.size(4)
    if height % 2 or width % 2:
        # With an odd size the even/odd slices differ in length; for sizes
        # 1 and 3 they would silently broadcast into a wrong result.
        raise ValueError(
            f"dwt_init needs even height and width, got {height}x{width}"
        )

    # Halving each pixel here (and each band in iwt_init) makes the pair an
    # exact inverse of one another, up to floating-point rounding.
    x01 = x[:, :, :, 0::2, :] / 2  # even rows
    x02 = x[:, :, :, 1::2, :] / 2  # odd rows
    x1 = x01[:, :, :, :, 0::2]  # even row, even column
    x2 = x02[:, :, :, :, 0::2]  # odd row, even column
    x3 = x01[:, :, :, :, 1::2]  # even row, odd column
    x4 = x02[:, :, :, :, 1::2]  # odd row, odd column
    x_LL = x1 + x2 + x3 + x4
    x_HL = -x1 - x2 + x3 + x4
    x_LH = -x1 + x2 - x3 + x4
    x_HH = x1 - x2 - x3 + x4

    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)


# Inverse 2-D discrete wavelet transform implemented with the Haar wavelet.
def iwt_init(x):
    """Invert one level of the 2-D Haar transform produced by ``dwt_init``.

    Axis 1 of ``x`` holds the four sub-bands (LL, HL, LH, HH) stacked back to
    back, each ``n`` entries long, the order in which ``dwt_init``
    concatenates them. Every output pixel is rebuilt from the four co-located
    coefficients, so the last two axes double in length.

    Args:
        x: 5-D tensor of shape ``(d0, 4 * n, d2, h, w)``.

    Returns:
        Tensor of shape ``(d0, n, d2, 2 * h, 2 * w)`` on ``x``'s device. Its
        dtype is that of ``x / 2``: floating-point input keeps its dtype,
        integer input yields the default floating-point dtype.

    Raises:
        ValueError: if ``x`` is not 5-D or axis 1 is not a multiple of 4.
    """
    r = 2
    if x.dim() != 5:
        raise ValueError(f"iwt_init expects a 5-D tensor, got {x.dim()}-D")
    T_time, in_batch, in_channel, in_height, in_width = x.size()
    if in_batch % (r ** 2) != 0:
        # Integer-dividing would silently drop the trailing entries.
        raise ValueError(
            f"iwt_init: axis 1 must be a multiple of {r ** 2}, got {in_batch}"
        )
    out_batch = in_batch // (r ** 2)

    # Slice out the four sub-bands. This extra 1/2 pairs with the 1/2 applied
    # in dwt_init so forward + inverse reproduces the input (up to rounding).
    x1, x2, x3, x4 = (
        x[:, k * out_batch:(k + 1) * out_batch] / 2 for k in range(r ** 2)
    )

    # Allocate directly on x's device with the bands' dtype, instead of a
    # float32 CPU tensor that is then copied over (and possibly down-cast).
    h = x1.new_zeros(
        (T_time, out_batch, in_channel, r * in_height, r * in_width)
    )

    # Each strided quadrant recovers one of the four pixels dwt_init grouped.
    h[:, :, :, 0::2, 0::2] = x1 - x2 - x3 + x4  # even row, even column
    h[:, :, :, 1::2, 0::2] = x1 - x2 + x3 - x4  # odd row, even column
    h[:, :, :, 0::2, 1::2] = x1 + x2 - x3 - x4  # even row, odd column
    h[:, :, :, 1::2, 1::2] = x1 + x2 + x3 + x4  # odd row, odd column

    return h


class DWT(nn.Module):
    """Parameter-free ``nn.Module`` wrapper that applies ``dwt_init``."""

    def __init__(self):
        super().__init__()
        # Plain attribute kept for compatibility. nn.Module never reads it, so
        # it does not freeze anything or stop gradients flowing back through
        # the input; this module simply owns no parameters.
        self.requires_grad = False

    def forward(self, x):
        """Return the Haar sub-bands of ``x`` as computed by ``dwt_init``."""
        return dwt_init(x)


class IWT(nn.Module):
    """Parameter-free ``nn.Module`` wrapper that applies ``iwt_init``."""

    def __init__(self):
        super().__init__()
        # Plain attribute kept for compatibility; nn.Module does not act on
        # it, and the module has no parameters of its own.
        self.requires_grad = False

    def forward(self, x):
        """Rebuild the signal from Haar sub-bands ``x`` via ``iwt_init``."""
        return iwt_init(x)