lprnet.py
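
"""LPRNet-style license plate recognition network.

A four-stage variant of LPRNet. The channel-wise pooling of the original
model (max_pool3d) is emulated in forward() with permute + max_pool2d so
the model exports cleanly to ONNX.
"""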
import torch
import torch.nn as nn
import torch.nn.functional as F

class small_basic_block(nn.Module):
    """Bottleneck block: 1x1 reduce, separable 3x1 / 1x3 convs, 1x1 expand."""
    def __init__(self, ch_in, ch_out):
        super(small_basic_block, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(ch_in, ch_out // 4, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(3, 1), padding=(1, 0)),
            nn.ReLU(),
            nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(1, 3), padding=(0, 1)),
            nn.ReLU(),
            nn.Conv2d(ch_out // 4, ch_out, kernel_size=1),
        )

    def forward(self, x):
        return self.block(x)

class my_lprnet(nn.Module):
    """LPRNet backbone split into four stages.

    The numeric comments (e.g. "4", "8", "11") appear to be the layer indices
    these modules had in the original single-Sequential LPRNet backbone.
    """
    def __init__(self, class_num):
        super(my_lprnet, self).__init__()
        self.class_num = class_num
        self.stage1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1),  # 0
            nn.BatchNorm2d(num_features=64),
            nn.ReLU())
        self.stage2 = nn.Sequential(
            small_basic_block(ch_in=64, ch_out=128),  # *** 4 ***
            nn.BatchNorm2d(num_features=128),
            nn.ReLU())
        # stage3 takes 64 input channels because forward() halves the 128
        # stage2 channels with a channel-wise max pool before this stage.
        self.stage3 = nn.Sequential(
            small_basic_block(ch_in=64, ch_out=256),  # 8
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),  # 10
            small_basic_block(ch_in=256, ch_out=256),  # *** 11 ***
            nn.BatchNorm2d(num_features=256),  # 12
            nn.ReLU())
        # stage4 likewise sees 256 / 4 = 64 channels after the second
        # channel-wise max pool in forward().
        self.stage4 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1),  # 16
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),  # 18
            nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1),  # 20
            nn.BatchNorm2d(num_features=class_num),
            nn.ReLU())
        # Global-context head: 64 + 128 + 256 = 448 channels from the first
        # three stages plus class_num channels from stage4.
        self.container = nn.Conv2d(in_channels=448 + class_num, out_channels=class_num, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        out1 = self.stage1(x)
        out = F.max_pool2d(out1, 3, stride=(1, 1))
        out2 = self.stage2(out)

        # This could be written with F.max_pool3d (pooling across channels as
        # well); it is done with permute + max_pool2d to make ONNX export easier.
        out = F.max_pool2d(out2, 3, stride=(1, 2))
        # Move channels to the last dimension and pool over it with stride 2,
        # reducing 128 channels to 64.
        out = F.max_pool2d(out.permute(0, 2, 3, 1).contiguous(), 1, stride=(1, 2))
        out = out.permute(0, 3, 1, 2).contiguous()

        out3 = self.stage3(out)

        out = F.max_pool2d(out3, 3, stride=(1, 2))
        # Same trick with stride 4 over channels: 256 -> 64.
        out = F.max_pool2d(out.permute(0, 2, 3, 1).contiguous(), 1, stride=(1, 4))
        out = out.permute(0, 3, 1, 2).contiguous()
        out4 = self.stage4(out)

        # Global context: pool each stage output to a common spatial size and
        # normalize it by the mean of its squared activations.
        out1 = F.avg_pool2d(out1, kernel_size=5, stride=5)
        f = torch.pow(out1, 2)
        f = torch.mean(f)
        out1 = torch.div(out1, f)

        out2 = F.avg_pool2d(out2, kernel_size=5, stride=5)
        f = torch.pow(out2, 2)
        f = torch.mean(f)
        out2 = torch.div(out2, f)

        out3 = F.avg_pool2d(out3, kernel_size=(4, 10), stride=(4, 2))
        f = torch.pow(out3, 2)
        f = torch.mean(f)
        out3 = torch.div(out3, f)

        f = torch.pow(out4, 2)
        f = torch.mean(f)
        out4 = torch.div(out4, f)

        logits = torch.cat((out1, out2, out3, out4), 1)
        logits = self.container(logits)
        # Collapse the height dimension: (N, class_num, sequence_length).
        logits = torch.mean(logits, dim=2)
        # logits = logits.view(self.class_num, -1)
        return logits

def build_lprnet(class_num, phase=False):
    """Build the network; phase=True returns it in train mode, else eval mode."""
    Net = my_lprnet(class_num)
    if phase:
        return Net.train()
    else:
        return Net.eval()
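

if __name__ == "__main__":
    # Minimal shape sanity check (not part of the original file). It assumes
    # the common LPRNet input size of 24 x 94 (H x W) and an example
    # class_num of 68; both values are assumptions, adjust to your dataset.
    net = build_lprnet(class_num=68, phase=False)
    dummy = torch.randn(1, 3, 24, 94)
    with torch.no_grad():
        logits = net(dummy)
    # Expected shape: (1, 68, 18) -> (batch, class_num, sequence_length).
    print(logits.shape)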