flops_table.py 2.69 KB
Newer Older
qianyj's avatar
qianyj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com

import torch

from ptflops import get_model_complexity_info


class FlopsEst(object):
    """Lookup table of per-choice FLOPs / params for a block-structured supernet.

    At construction time every candidate module in ``model.blocks`` is profiled
    with ptflops' ``get_model_complexity_info`` on the feature-map shape it
    receives. The results are stored, in millions, in
    ``flops_dict[block_id][choice_id]`` and ``params_dict[block_id][choice_id]``.
    The cost of the modules outside the searchable blocks (``conv_stem``,
    ``global_pool`` and ``conv_head``) is accumulated into ``flops_fixed`` and
    ``params_fixed``. ``get_flops`` and ``get_params`` then price a sampled
    architecture by summing table entries, without running the model again.

    The model must expose ``conv_stem``, ``blocks`` (indexable, each element
    iterable over its candidate modules), ``global_pool`` and ``conv_head``.
    """

    def __init__(self, model, input_shape=(2, 3, 224, 224), device='cpu'):
        """Profile ``model`` and build the FLOPs / params tables.

        Args:
            model: supernet with ``conv_stem``, ``blocks``, ``global_pool`` and
                ``conv_head`` attributes. It is moved to the requested device
                via ``.cpu()`` / ``.cuda()``, which modify the module in place.
            input_shape: shape of the dummy input, batch dimension first.
                ``input_shape[1:]`` is the per-sample resolution used for
                profiling ``conv_stem``.
            device: ``'cpu'`` to profile on CPU. Any other value uses the
                current CUDA device.
        """
        self.block_num = len(model.blocks)
        self.choice_num = len(model.blocks[0])
        self.flops_dict = {}
        self.params_dict = {}

        if device == 'cpu':
            model = model.cpu()
            run_device = torch.device('cpu')
        else:
            model = model.cuda()
            # Same target as ``.cuda()`` without an index: the current CUDA device.
            run_device = torch.device('cuda')

        self.params_fixed = 0
        self.flops_fixed = 0

        # The dummy tensor must live on the same device as the model; otherwise
        # the shape-propagation forwards below fail for device != 'cpu'.
        x = torch.randn(input_shape, device=run_device)

        # Profile the stem at the same per-sample resolution that is actually fed
        # through the network (e.g. (3, 224, 224) for the default input_shape).
        self._add_fixed(model.conv_stem, tuple(input_shape[1:]))
        x = self._propagate(model.conv_stem, x)

        for block_id, block in enumerate(model.blocks):
            self.flops_dict[block_id] = {}
            self.params_dict[block_id] = {}
            for module_id, module in enumerate(block):
                flops, params = self._profile(module, tuple(x.shape[1:]))
                # Flops(M)
                self.flops_dict[block_id][module_id] = flops
                # Params(M)
                self.params_dict[block_id][module_id] = params

            # Only the last choice of the block is run to obtain the next
            # block's input. NOTE(review): this assumes every choice in a block
            # produces the same output shape; confirm against the supernet.
            x = self._propagate(module, x)

        # global_pool
        self._add_fixed(model.global_pool, tuple(x.shape[1:]))
        x = self._propagate(model.global_pool, x)

        # conv_head, profiled on the pooled features
        self._add_fixed(model.conv_head, tuple(x.shape[1:]))

    @staticmethod
    def _profile(module, input_res):
        """Return ``(flops, params)`` of ``module`` in millions, as reported by ptflops."""
        flops, params = get_model_complexity_info(
            module, input_res, as_strings=False, print_per_layer_stat=False)
        return flops / 1e6, params / 1e6

    def _add_fixed(self, module, input_res):
        """Profile a non-searchable module and add its cost to the fixed totals."""
        flops, params = self._profile(module, input_res)
        self.params_fixed += params
        self.flops_fixed += flops

    @staticmethod
    def _propagate(module, x):
        """Run ``module`` on ``x`` only to obtain the next feature map (no autograd graph)."""
        with torch.no_grad():
            return module(x)

    def get_params(self, arch):
        """Return the parameter count (M) of a sampled architecture.

        Args:
            arch: sequence with one entry per block. Each entry is the index of
                the chosen module in that block, or ``-1`` to leave the block
                out.

        Returns:
            Sum of the chosen modules' params plus ``params_fixed``.
        """
        params = 0
        for block_id, choice_id in enumerate(arch):
            if choice_id == -1:
                continue
            params += self.params_dict[block_id][choice_id]
        return params + self.params_fixed

    def get_flops(self, arch):
        """Return the FLOPs (M) of a sampled architecture.

        Args:
            arch: ordered mapping from layer-choice key to a per-choice mask.
                The position of a key in iteration order is used as the block
                index into ``flops_dict``. Every truthy mask entry adds that
                choice's FLOPs. Keys ``'LayerChoice1'`` and ``'LayerChoice23'``
                are skipped (NOTE(review): presumably fixed, non-searched
                layers; confirm).

        Returns:
            Sum of the selected choices' FLOPs plus ``flops_fixed``.
        """
        flops = 0
        for block_id, key in enumerate(arch):
            # Compare the key itself, not the integer position, so that both
            # names are really excluded.
            if key in ('LayerChoice1', 'LayerChoice23'):
                continue
            for idx, choice in enumerate(arch[key]):
                if choice:
                    flops += self.flops_dict[block_id][idx]
        return flops + self.flops_fixed