import os
import tempfile

from d2go.utils.flop_calculator import dump_flops_info
from d2go.utils.testing.data_loader_helper import create_fake_detection_data_loader
from d2go.utils.testing.rcnn_helper import RCNNBaseTestCases


class TestFlopCount(RCNNBaseTestCases.TemplateTestCase):
    """Smoke test that ``dump_flops_info`` writes its FLOP report files."""

    def setup_custom_test(self):
        """Extend the template setup with the FBNetV3A DSMask C4 Mask R-CNN config.

        Runs the base-class ``setup_custom_test`` first, then merges
        ``detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml`` into ``self.cfg``.
        """
        super().setup_custom_test()
        self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")

    def test_flop_count(self):
        """Run ``dump_flops_info`` on one fake batch and check its output files.

        ``self.test_model`` is not created in this class; presumably the
        template base class builds it from ``self.cfg`` (TODO confirm against
        ``RCNNBaseTestCases``).
        """
        # Image height is the backbone's size_divisibility, floored at 10
        # (presumably because size_divisibility may be 0 when no padding is
        # needed -- confirm); the width is twice the height.
        size_divisibility = max(self.test_model.backbone.size_divisibility, 10)
        h, w = size_divisibility, size_divisibility * 2
        with create_fake_detection_data_loader(h, w, is_train=False) as data_loader:
            # Take a single batch and wrap it in a 1-tuple, which is what
            # dump_flops_info receives as `inputs` below.
            inputs = (next(iter(data_loader)),)

        with tempfile.TemporaryDirectory(prefix="d2go_test") as output_dir:
            dump_flops_info(self.test_model, inputs, output_dir)

            # Each report is expected as "<name>.txt" directly under output_dir.
            for fname in (
                "flops_str_mobilecv",
                "flops_str_fvcore",
                "flops_table_fvcore",
            ):
                # subTest lets a subTest-aware runner report every missing
                # file rather than only the first; the message names the path.
                with self.subTest(fname=fname):
                    outf = os.path.join(output_dir, fname + ".txt")
                    self.assertTrue(
                        os.path.isfile(outf), f"flops report not written: {outf}"
                    )