efficientdet.py 2.19 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn as nn
import math
from .efficientnet import EfficientNet
from .bifpn import BIFPN
from .retinahead import RetinaHead
from torchvision.ops import nms 
import torch.nn.functional as F

MODEL_MAP = {
    'efficientdet-d0': 'efficientnet-b0',
    'efficientdet-d1': 'efficientnet-b1',
    'efficientdet-d2': 'efficientnet-b2',
    'efficientdet-d3': 'efficientnet-b3',
    'efficientdet-d4': 'efficientnet-b4',
    'efficientdet-d5': 'efficientnet-b5',
    'efficientdet-d6': 'efficientnet-b6',
    'efficientdet-d7': 'efficientnet-b6',
}
class EfficientDet(nn.Module):
    def __init__(self,
                 intermediate_channels,
                 network = 'efficientdet-d1',
                 D_bifpn=3,
                 W_bifpn=32,
                 D_class=3,
                 scale_ratios = [0.5, 1, 2, 4, 8,16,32],
                 ):
        super(EfficientDet, self).__init__()
        self.backbone = EfficientNet.from_pretrained(MODEL_MAP[network])
        self.neck = BIFPN(in_channels=self.backbone.get_list_features(),
                                out_channels=W_bifpn,
                                stack=D_bifpn,
                                num_outs=7)
        self.bbox_head = RetinaHead(num_classes = intermediate_channels,
                                    in_channels = W_bifpn)

        self.scale_ratios = scale_ratios
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        self.freeze_bn()

    def forward(self, inputs):
        x = self.extract_feat(inputs)
        outs = self.bbox_head(x)

        return outs[0][1]
        
    def freeze_bn(self):
        '''Freeze BatchNorm layers.'''
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()
    def extract_feat(self, img):
        """
            Directly extract features from the backbone+neck
        """
        x = self.backbone(img)
        x = self.neck(x)
        return x