#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import torch
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances

from d2go.tests.data_loader_helper import create_local_dataset


@META_ARCH_REGISTRY.register()
class DetMetaArchForTest(torch.nn.Module):
    """Minimal detection meta-architecture used to exercise training and inference in tests."""

    def __init__(self, cfg):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, stride=1, padding=1)
        self.bn = torch.nn.BatchNorm2d(4)
        self.relu = torch.nn.ReLU(inplace=True)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        # buffer that is updated in forward() during training; used to simulate
        # weight updates from an optimization step
        self.register_buffer("scale_weight", torch.Tensor([0.0]))

    @property
    def device(self):
        return self.conv.weight.device

    def forward(self, inputs):
        if not self.training:
            return self.inference(inputs)

        images = [x["image"].to(self.device) for x in inputs]
        images = ImageList.from_tensors(images, 1)
        ret = self.conv(images.tensor)
        ret = self.bn(ret)
        ret = self.relu(ret)
        ret = self.avgpool(ret)

        # simulate weight updates
        self.scale_weight.fill_(1.0)

        return {"loss": ret.norm()}

    def inference(self, inputs):
        # return a single fixed detection; the box is scaled by `scale_weight`
        # so tests can verify that training updated the buffer
        instance = Instances((10, 10))
        instance.pred_boxes = Boxes(
            torch.tensor([[2.5, 2.5, 7.5, 7.5]], device=self.device) * self.scale_weight
        )
        instance.scores = torch.tensor([0.9], device=self.device)
        instance.pred_classes = torch.tensor([1], dtype=torch.int32, device=self.device)
        ret = [{"instances": instance}]
        return ret


def get_det_meta_arch_cfg(cfg, dataset_name, output_dir):
    """Set up `cfg` to train/test DetMetaArchForTest on `dataset_name` with a tiny schedule."""
    cfg.MODEL.DEVICE = "cpu"
    cfg.MODEL.META_ARCHITECTURE = "DetMetaArchForTest"

    cfg.DATASETS.TRAIN = (dataset_name,)
    cfg.DATASETS.TEST = (dataset_name,)

    cfg.INPUT.MIN_SIZE_TRAIN = (10,)
    cfg.INPUT.MIN_SIZE_TEST = (10,)

    cfg.SOLVER.MAX_ITER = 5
    cfg.SOLVER.STEPS = [2]
    cfg.SOLVER.WARMUP_ITERS = 1
    cfg.SOLVER.CHECKPOINT_PERIOD = 1
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.REFERENCE_WORLD_SIZE = 0

    cfg.OUTPUT_DIR = output_dir

    return cfg


def create_detection_cfg(runner, output_dir):
    """Create a small local dataset under `output_dir` and return a config that uses it."""
    ds_name = create_local_dataset(output_dir, 5, 10, 10)
    cfg = runner.get_default_cfg()
    return get_det_meta_arch_cfg(cfg, ds_name, output_dir)
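

# Minimal usage sketch (illustrative only, not part of the original helpers):
# in a test, the helpers above are typically combined with a d2go runner to
# build a config backed by a small on-disk dataset. `GeneralizedRCNNRunner`
# and `build_model` below reflect the common d2go runner API and are
# assumptions about the surrounding test setup, not part of this module.
#
#     import tempfile
#     from d2go.runner import GeneralizedRCNNRunner
#
#     with tempfile.TemporaryDirectory() as output_dir:
#         runner = GeneralizedRCNNRunner()
#         cfg = create_detection_cfg(runner, output_dir)
#         model = runner.build_model(cfg)  # builds the registered DetMetaArchForTest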