swin_transformer.py 648 Bytes
Newer Older
HHL's avatar
v  
HHL committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from torch import nn
from mmdet.models import build_backbone, build_neck


class VisionBackbone(nn.Module):
    """Image feature extractor made of a backbone followed by a neck.

    Both sub-modules are created through the mmdet builders from the
    ``backbone_cfg`` and ``neck_cfg`` attributes of ``config``. The
    ``vision_freeze`` attribute of ``config`` decides whether the forward
    pass runs with gradient tracking disabled.
    """

    def __init__(self, config):
        """Build the backbone and neck and remember the freeze flag.

        Args:
            config: Object exposing ``backbone_cfg``, ``neck_cfg`` and
                ``vision_freeze`` attributes.
        """
        super().__init__()
        self.backbone = build_backbone(config.backbone_cfg)
        self.neck = build_neck(config.neck_cfg)
        self.freeze = config.vision_freeze

    def _extract(self, img):
        """Feed ``img`` through the backbone, then its output through the neck."""
        backbone_out = self.backbone(img)
        return self.neck(backbone_out)

    def forward(self, img):
        """Return the neck output computed from ``img``.

        When ``self.freeze`` is truthy the whole computation runs inside
        ``torch.no_grad()``, so no autograd graph is recorded for it.
        Otherwise the sub-modules are called as-is and gradient tracking
        follows whatever mode the caller is in.
        """
        if not self.freeze:
            return self._extract(img)

        with torch.no_grad():
            return self._extract(img)