flops.py 613 Bytes
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import time
from timm import create_model
import model
import utils
from fvcore.nn import FlopCountAnalysis

# Print the trainable-parameter count and the FLOP count of each model listed
# below, using a single random image at the listed resolution.

# NOTE(review): T0 and T1 are not referenced anywhere in this script;
# presumably timing windows left over from a benchmarking script -- confirm
# before removing.
T0 = 5
T1 = 10

# Each entry is (timm model name, batch size, square input resolution).
# batch_size is unpacked but unused: the analysis below always runs on a
# batch of one image.
for n, batch_size, resolution in [
    ('repvit_m0_9', 1024, 224),
]:
    # Single 3-channel image of shape (1, 3, resolution, resolution).
    inputs = torch.randn(1, 3, resolution, resolution)

    # Bound to `net` rather than `model` so the imported `model` module is
    # not overwritten by the first loop iteration.
    net = create_model(n, num_classes=1000)

    # NOTE(review): project helper; presumably folds/removes batch-norm
    # layers for inference -- verify in utils.replace_batchnorm.
    utils.replace_batchnorm(net)

    # Analyse the inference path: put every (possibly newly created)
    # submodule in eval mode so train-only behaviour is not traced.
    net.eval()

    # Count only parameters that require gradients; reported in units of 1e6.
    n_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('number of params:', n_parameters / 1e6)

    # FlopCountAnalysis is evaluated when total() is called; reported in
    # units of 1e9.
    flops = FlopCountAnalysis(net, inputs)
    print("flops: ", flops.total() / 1e9)