step2_rename_huge.py 2.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import math
import os
from collections import OrderedDict

import torch

parser = argparse.ArgumentParser(
    description='Rename checkpoint state_dict keys for the new op and '
                'convert the weights to fp16.')
# Required positional: the script always loads this file, so an optional
# argument defaulting to None only postponed the failure to torch.load.
parser.add_argument('filename', type=str,
                    help='path to the input checkpoint (must contain a '
                         '"state_dict" entry)')
args = parser.parse_args()


def gen_grid(n_heads, n_points=9):
    """Build a square grid of (y, x) offsets, replicated once per head.

    The grid has ``kernel_size x kernel_size`` points with
    ``kernel_size = sqrt(n_points)``. Along each axis the coordinates run
    from ``-kernel_size // 2 + 1`` to ``kernel_size // 2`` in unit steps
    (for the default ``n_points=9`` that is -1, 0, 1).

    Args:
        n_heads: number of copies of the grid (size of the leading dim).
        n_points: total number of grid points; must be a perfect square.
            Defaults to 9 (a 3x3 grid), the value previously hard-coded.

    Returns:
        float32 tensor of shape ``(n_heads, n_points, 2)``. The last axis is
        ``(y, x)``; points are ordered row-major (y outer, x inner).

    Raises:
        ValueError: if ``n_points`` is not a perfect square.
    """
    kernel_size = int(math.sqrt(n_points))
    if kernel_size * kernel_size != n_points:
        raise ValueError(
            'n_points must be a perfect square, got {}'.format(n_points))

    # The same 1-D coordinate vector is used for both axes.
    coords = torch.linspace(-kernel_size // 2 + 1, kernel_size // 2,
                            kernel_size, dtype=torch.float32)
    # torch.meshgrid defaults to 'ij' indexing: y varies along dim 0 and
    # x along dim 1, which yields the row-major point order documented above.
    y, x = torch.meshgrid(coords, coords)

    # (k, k, 2) -> (n_points, 1, 2) -> (n_points, n_heads, 2)
    # -> (n_heads, n_points, 2)
    grid = torch.stack([y, x], -1).reshape(-1, 1, 2)
    return grid.repeat(1, n_heads, 1).permute(1, 0, 2)


def convert_to_newop(m):
    """Rename state_dict keys for the new op and cast float weights to fp16.

    Rules applied to every key ``k`` (all plain substring matches):

    * keys containing ``'ema'`` are dropped;
    * ``'attn'`` -> ``'dcn'``; within such keys additionally
      ``'sampling_offsets'`` -> ``'offset'``,
      ``'attention_weights'`` -> ``'mask'`` and
      ``'value_proj'`` -> ``'input_proj'``;
    * ``'.norm1_k.'``, ``'.norm1_q.'``, ``'.norm1_v.'`` and ``'fc_norm.'``
      gain an extra ``'0.'`` index level (e.g. ``'.norm1_k.'`` ->
      ``'.norm1_k.0.'``);
    * in keys containing ``'.post_norms.'``, ``'.bias'`` / ``'.weight'``
      become ``'.0.bias'`` / ``'.0.weight'``.

    Args:
        m: mapping of parameter name -> value, e.g. a loaded ``state_dict``.

    Returns:
        OrderedDict with the renamed keys in their original order.
        Floating-point tensors are cast to float16; every other value
        (integer/bool tensors, non-tensor entries) is kept unchanged.
    """
    new_sd = OrderedDict()
    for k, v in m.items():
        # Drop EMA copies before doing any renaming work.
        # NOTE(review): substring match, so ANY key containing 'ema' is
        # dropped -- confirm no regular parameter name contains it.
        if 'ema' in k:
            continue

        new_k = k
        if 'attn' in k:
            new_k = new_k.replace('attn', 'dcn')
            if 'sampling_offsets' in k:
                new_k = new_k.replace('sampling_offsets', 'offset')
            if 'attention_weights' in k:
                new_k = new_k.replace('attention_weights', 'mask')
            if 'value_proj' in k:
                new_k = new_k.replace('value_proj', 'input_proj')

        # These norm layers gain an extra '.0' index level (presumably the
        # target model wraps them in an nn.Sequential -- confirm).
        for norm in ('.norm1_k.', '.norm1_q.', '.norm1_v.'):
            if norm in k:
                new_k = new_k.replace(norm, norm + '0.')
        if '.post_norms.' in k:
            new_k = new_k.replace('.bias', '.0.bias')
            new_k = new_k.replace('.weight', '.0.weight')
        if 'fc_norm.' in k:
            new_k = new_k.replace('fc_norm.', 'fc_norm.0.')

        # Only floating-point tensors go to fp16. float16 represents
        # integers exactly only up to 2048, so casting integer buffers
        # (indices, counters) would corrupt them; non-tensors have no .half().
        if torch.is_tensor(v) and v.is_floating_point():
            v = v.half()
        new_sd[new_k] = v

    return new_sd


# NOTE: torch.load unpickles the file -- only run this on trusted checkpoints.
model = torch.load(args.filename, map_location=torch.device('cpu'))['state_dict']
new_model = {"state_dict": convert_to_newop(model)}
# Insert "_rename" before the file extension. Unlike str.replace(".pth", ...),
# this never touches directory names and always yields a path different from
# the input, so the source checkpoint cannot be overwritten.
root, ext = os.path.splitext(args.filename)
torch.save(new_model, root + "_rename" + ext)