convert_hourglass_weight.py 905 Bytes
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

MODEL_PATH = '../../models/ExtremeNet_500000.pkl'
OUT_PATH = '../../models/ExtremeNet_500000.pth'

import torch
state_dict = torch.load(MODEL_PATH)
key_map = {'t_heats': 'hm_t', 'l_heats': 'hm_l', 'b_heats': 'hm_b', \
           'r_heats': 'hm_r', 'ct_heats': 'hm_c', \
           't_regrs': 'reg_t', 'l_regrs': 'reg_l', \
           'b_regrs': 'reg_b', 'r_regrs': 'reg_r'}

out = {}
for k in state_dict.keys():
  changed = False
  for m in key_map.keys():
    if m in k:
      if 'ct_heats' in k and m == 't_heats':
        continue
      new_k = k.replace(m, key_map[m])
      out[new_k] = state_dict[k]
      changed = True
      print('replace {} to {}'.format(k, new_k))
  if not changed:
    out[k] = state_dict[k]
data = {'epoch': 0,
        'state_dict': out}
torch.save(data, OUT_PATH)