test_detr_export.py 2.56 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import torch
from d2go.runner import create_runner
from detr.util.misc import nested_tensor_from_tensor_list
from fvcore.nn import flop_count_table, FlopCountAnalysis


class Tester(unittest.TestCase):
    """Check that DETR models built through the d2go runner can be scripted.

    Each test builds a model on CPU, takes its ``detr`` sub-module, compiles
    it with ``torch.jit.script`` and verifies that the eager and scripted
    modules return identical ``pred_logits`` and ``pred_boxes``.
    """

    # Runner name handed to ``create_runner`` by every test in this class.
    _RUNNER_NAME = "d2go.projects.detr.runner.DETRRunner"

    # Output dictionary entries compared between eager and scripted modules.
    _COMPARED_OUTPUTS = ("pred_logits", "pred_boxes")

    @staticmethod
    def _set_detr_cfg(cfg, enc_layers, dec_layers, num_queries, dim_feedforward):
        """Write the DETR-specific options used by these tests into ``cfg``.

        Args:
            cfg: config node to mutate in place (``cfg.MODEL`` and
                ``cfg.MODEL.DETR`` must already exist).
            enc_layers: value stored in ``MODEL.DETR.ENC_LAYERS``.
            dec_layers: value stored in ``MODEL.DETR.DEC_LAYERS``.
            num_queries: value stored in ``MODEL.DETR.NUM_OBJECT_QUERIES``.
            dim_feedforward: value stored in ``MODEL.DETR.DIM_FEEDFORWARD``
                (the tests below pass 2048 for ResNet-50 and 256 for FBNet).
        """
        cfg.MODEL.META_ARCHITECTURE = "Detr"
        cfg.MODEL.DETR.NUM_OBJECT_QUERIES = num_queries
        cfg.MODEL.DETR.ENC_LAYERS = enc_layers
        cfg.MODEL.DETR.DEC_LAYERS = dec_layers
        # Deep supervision is switched off for every export test.
        cfg.MODEL.DETR.DEEP_SUPERVISION = False
        cfg.MODEL.DETR.DIM_FEEDFORWARD = dim_feedforward

    def _create_runner_and_cfg(self):
        """Return ``(runner, cfg)`` with the runner's default config on CPU."""
        runner = create_runner(self._RUNNER_NAME)
        cfg = runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        return runner, cfg

    @staticmethod
    def _build_detr(runner, cfg):
        """Build the model described by ``cfg`` in eval mode; return its ``detr`` module."""
        return runner.build_model(cfg).eval().detr

    def _assert_same_tensor(self, name, expected, actual):
        """Fail with a descriptive message unless both tensors are exactly equal.

        Args:
            name: label of the compared output, used in the failure message.
            expected: tensor produced by the eager module.
            actual: tensor produced by the scripted module.
        """
        # Report shape mismatches separately: the element-wise difference
        # below is only meaningful for equally shaped tensors.
        if expected.shape != actual.shape:
            self.fail(
                f"{name}: eager shape {tuple(expected.shape)} != "
                f"scripted shape {tuple(actual.shape)}"
            )
        if not expected.equal(actual):
            max_diff = (expected - actual).abs().max().item()
            self.fail(
                f"{name}: eager and scripted outputs differ "
                f"(max abs diff = {max_diff})"
            )

    def _assert_model_output(self, model, scripted_model):
        """Run both modules on the same two-image batch and compare outputs.

        Args:
            model: eager module; called with a nested tensor and expected to
                return a mapping containing ``pred_logits`` and ``pred_boxes``.
            scripted_model: scripted counterpart, called with the same input.
        """
        # A private, seeded generator makes failures reproducible without
        # touching the global RNG state shared with other tests.
        gen = torch.Generator().manual_seed(0)
        # Two images of different widths are batched into one nested tensor.
        x = nested_tensor_from_tensor_list(
            [
                torch.rand(3, 200, 200, generator=gen),
                torch.rand(3, 200, 250, generator=gen),
            ]
        )
        # Only forward results are compared, so skip autograd bookkeeping.
        with torch.no_grad():
            out = model(x)
            out_script = scripted_model(x)
        for key in self._COMPARED_OUTPUTS:
            self._assert_same_tensor(key, out[key], out_script[key])

    def test_detr_res50_export(self):
        """Script DETR on a ResNet-50 backbone and compare with eager output."""
        runner, cfg = self._create_runner_and_cfg()
        # DETR
        self._set_detr_cfg(cfg, 6, 6, 100, 2048)
        # backbone
        cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
        cfg.MODEL.RESNETS.DEPTH = 50
        cfg.MODEL.RESNETS.STRIDE_IN_1X1 = False
        cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
        # build model
        model = self._build_detr(runner, cfg)
        scripted_model = torch.jit.script(model)
        self._assert_model_output(model, scripted_model)

    def test_detr_fbnet_export(self):
        """Script DETR on an FBNetV3 backbone, compare outputs and print FLOPs."""
        runner, cfg = self._create_runner_and_cfg()
        # DETR
        self._set_detr_cfg(cfg, 3, 3, 50, 256)
        # backbone
        cfg.MODEL.BACKBONE.NAME = "FBNetV2C4Backbone"
        cfg.MODEL.FBNET_V2.ARCH = "FBNetV3_A_dsmask_C5"
        cfg.MODEL.FBNET_V2.WIDTH_DIVISOR = 8
        cfg.MODEL.FBNET_V2.OUT_FEATURES = ["trunk4"]
        # build model
        model = self._build_detr(runner, cfg)
        print(model)
        scripted_model = torch.jit.script(model)
        self._assert_model_output(model, scripted_model)
        # print flops (informational only; nothing is asserted on the table)
        table = flop_count_table(FlopCountAnalysis(model, ([torch.rand(3, 224, 320)],)))
        print(table)