test_flop_count.py 1.19 KB
Newer Older
Yanghan Wang's avatar
Yanghan Wang committed
1
import os
2
import tempfile
Yanghan Wang's avatar
Yanghan Wang committed
3

4
from d2go.utils.flop_calculator import dump_flops_info
5
6
7
from d2go.utils.testing.data_loader_helper import (
    create_detection_data_loader_on_toy_dataset,
)
Yanghan Wang's avatar
Yanghan Wang committed
8
from d2go.utils.testing.rcnn_helper import RCNNBaseTestCases
9
10
11
12
13
14
15
16
17
18


class TestFlopCount(RCNNBaseTestCases.TemplateTestCase):
    """Check that ``dump_flops_info`` writes its FLOP reports for a Mask R-CNN model.

    The model is built from the ``mask_rcnn_fbnetv3a_dsmask_C4.yaml`` config.
    FLOPs are dumped for a single batch from a toy detection dataset, and the
    test asserts that each expected ``.txt`` report exists in the output
    directory.
    """

    def setup_custom_test(self):
        """Merge the FBNetV3A DSMask C4 Mask R-CNN config on top of the template setup."""
        super().setup_custom_test()
        self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")

    def test_flop_count(self):
        """Dump FLOP info for one toy batch and check every report file is written."""
        # Image sides are derived from the backbone's size divisibility (w = 2 * h).
        # The floor of 10 keeps the image non-trivial if the backbone reports a
        # small or zero divisibility.
        size_divisibility = max(self.test_model.backbone.size_divisibility, 10)
        h, w = size_divisibility, size_divisibility * 2

        with create_detection_data_loader_on_toy_dataset(
            self.cfg, h, w, is_train=False
        ) as data_loader:
            # A single batch is enough for counting. It is wrapped in a 1-tuple,
            # which is the form passed to dump_flops_info as `inputs` below.
            inputs = (next(iter(data_loader)),)

        # Report basenames (".txt" is appended) that dump_flops_info should produce.
        expected_reports = (
            "flops_str_mobilecv",
            "flops_str_fvcore",
            "flops_table_fvcore",
        )

        with tempfile.TemporaryDirectory(prefix="d2go_test") as output_dir:
            dump_flops_info(self.test_model, inputs, output_dir)

            # One subTest per report, so that under a subtest-aware runner a failure
            # lists every missing file instead of stopping at the first one.
            # The message names the missing path explicitly.
            for fname in expected_reports:
                with self.subTest(fname=fname):
                    outf = os.path.join(output_dir, fname + ".txt")
                    self.assertTrue(
                        os.path.isfile(outf),
                        f"expected FLOPs report was not written: {outf}",
                    )