efficientnetv2_to_mmpretrain.py 3.85 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Copyright (c) OpenMMLab. All rights reserved.
"""Convert EfficientNetV2 weights from timm
(https://github.com/rwightman/pytorch-image-models) to mmpretrain format."""
import argparse
import os.path as osp

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def _rename_block_key(name):
    """Rename one timm ``blocks.<stage>.<index>.<...>`` key to mmpretrain.

    timm stage ``s`` is stored as ``backbone.layers.<s + 1>`` because
    ``backbone.layers.0`` holds the stem. The sub-module names inside a block
    depend on the stage:

    * stage 0: ``conv``/``bn1``;
    * stages 1-2: ``conv_exp``/``conv_pwl`` with ``bn1``/``bn2``;
    * later stages: ``conv_pw``/``conv_dw``/``conv_pwl`` with
      ``bn1``/``bn2``/``bn3`` and an ``se`` module.

    This stage-to-layout split is fixed and is assumed to match every
    EfficientNetV2 variant this script targets.
    """
    _, stage_str, rest = name.split('.', 2)
    stage = int(stage_str)
    if stage == 0:
        renames = (('bn1', 'bn'), )
    elif stage < 3:
        renames = (
            ('conv_exp', 'conv1.conv'),
            ('conv_pwl', 'conv2.conv'),
            ('bn1', 'conv1.bn'),
            ('bn2', 'conv2.bn'),
        )
    else:
        # Order matters: 'conv_pwl' must be rewritten before its prefix
        # 'conv_pw', otherwise it would become 'expand_conv.convl'.
        renames = (
            ('conv_pwl', 'linear_conv.conv'),
            ('conv_pw', 'expand_conv.conv'),
            ('conv_dw', 'depthwise_conv.conv'),
            ('bn1', 'expand_conv.bn'),
            ('bn2', 'depthwise_conv.bn'),
            ('bn3', 'linear_conv.bn'),
            ('se.conv_reduce', 'se.conv1.conv'),
            ('se.conv_expand', 'se.conv2.conv'),
        )
    for old, new in renames:
        rest = rest.replace(old, new)
    return f'backbone.layers.{stage + 1}.{rest}'


def _rename_top_level_key(name, head_stage):
    """Rename a non-block key (stem, head conv/bn or classifier).

    Keys whose first component is not one of the known modules are returned
    unchanged.
    """
    module, sep, rest = name.partition('.')
    targets = {
        'conv_stem': 'backbone.layers.0.conv',
        'bn1': 'backbone.layers.0.bn',
        'conv_head': f'backbone.layers.{head_stage}.conv',
        'bn2': f'backbone.layers.{head_stage}.bn',
        'classifier': 'head.fc',
    }
    if module not in targets:
        return name
    return targets[module] + sep + rest


def convert_from_efficientnetv2_timm(param):
    """Rename the keys of a timm EfficientNetV2 state dict for mmpretrain.

    Args:
        param (Mapping[str, Any]): The timm state dict. Values are passed
            through untouched; only the keys are rewritten.

    Returns:
        dict: A new dict holding the same values under mmpretrain key names,
        in the original key order.

    Raises:
        ValueError: If ``param`` has no ``blocks.*`` keys, which means the
            index of the head layer cannot be determined.
    """
    # Derive the head layer index from the block keys themselves instead of
    # from the position of one particular key. The last timm stage ``s``
    # maps to ``backbone.layers.<s + 1>``, so ``conv_head``/``bn2`` go into
    # layer ``s + 2``. Per the original converter note, that is 7 for
    # s/base/b1/b2/b3 and 8 for m/l/xl.
    stages = [
        int(key.split('.')[1]) for key in param
        if key.startswith('blocks.')
    ]
    if not stages:
        raise ValueError('No "blocks.*" keys found; the input does not look '
                         'like a timm EfficientNetV2 state dict.')
    head_stage = max(stages) + 2

    converted = {}
    for name, data in param.items():
        if name.startswith('blocks.'):
            new_name = _rename_block_key(name)
        else:
            new_name = _rename_top_level_key(name, head_stage)
        converted[new_name] = data
    return converted


def main():
    """Command-line entry point: convert a timm EfficientNetV2 checkpoint.

    Loads the checkpoint given as ``src`` (path or URL) onto the CPU, renames
    its keys with ``convert_from_efficientnetv2_timm`` and writes the result
    to ``dst`` with ``torch.save``, creating the parent directory if needed.
    """
    parser = argparse.ArgumentParser(
        description=('Convert pretrained efficientnetv2 '
                     'models in timm to mmpretrain style.'))
    parser.add_argument('src', help='src model path or url')
    # `dst` is expected to be the full file path of the checkpoint to write.
    parser.add_argument('dst', help='save path')
    cli = parser.parse_args()

    ckpt = CheckpointLoader.load_checkpoint(cli.src, map_location='cpu')
    # Accept both wrapped checkpoints ({'state_dict': ...}) and bare ones.
    source_weights = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt

    converted = convert_from_efficientnetv2_timm(source_weights)
    out_dir = osp.dirname(cli.dst)
    mmengine.mkdir_or_exist(out_dir)
    torch.save(converted, cli.dst)

    print('Done!!')


if __name__ == '__main__':
    # Run the converter only when executed as a script, not on import.
    main()