module.py 1.67 KB
Newer Older
Casper Hansen's avatar
Casper Hansen committed
1
import torch.nn as nn
Ji Lin's avatar
Ji Lin committed
2

Casper's avatar
Casper committed
3

Casper Hansen's avatar
Casper Hansen committed
4
5
def get_named_linears(module):
    """Collect every ``nn.Linear`` reachable from *module*.

    Args:
        module: The root ``nn.Module`` to walk (the root itself is included,
            under the empty-string name, if it is an ``nn.Linear``).

    Returns:
        dict mapping each linear layer's dotted name (as yielded by
        ``module.named_modules()``) to the layer object, in traversal order.
    """
    linears = {}
    for qualified_name, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            linears[qualified_name] = submodule
    return linears
Ji Lin's avatar
Ji Lin committed
6

Casper's avatar
Casper committed
7

Ji Lin's avatar
Ji Lin committed
8
9
10
11
12
13
14
15
def get_op_by_name(module, op_name):
    """Look up a submodule of *module* by its dotted name.

    Args:
        module: The root ``nn.Module`` to search.
        op_name: Name relative to *module*, as produced by
            ``module.named_modules()`` (``""`` refers to *module* itself).

    Returns:
        The first submodule whose name equals *op_name*.

    Raises:
        ValueError: If no submodule has that name.
    """
    missing = object()
    found = next(
        (sub for qual_name, sub in module.named_modules() if qual_name == op_name),
        missing,
    )
    if found is missing:
        raise ValueError(f"Cannot find op {op_name} in module {module}")
    return found


16
def set_op_by_name(layer, name, new_module):
    """Replace the submodule at dotted path *name* inside *layer*.

    Every path segment except the last is resolved to walk down to the
    owning module: purely-digit segments are resolved with integer
    indexing (``owner[int(part)]``), all others with ``getattr``. The final
    segment is then assigned on that owner with ``setattr``.

    Args:
        layer: Root module the path is relative to.
        name: Dotted path such as ``"mlp.fc1"`` or ``"layers.0.attn"``; a
            name without dots is set directly on *layer*.
        new_module: Object to install at that path.
    """
    *parent_path, leaf = name.split(".")
    owner = layer
    for part in parent_path:
        owner = owner[int(part)] if part.isdigit() else getattr(owner, part)
    setattr(owner, leaf, new_module)


Ji Lin's avatar
Ji Lin committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def get_op_name(module, op):
    """Return the dotted name of submodule *op* relative to *module*.

    Matching is by object identity (``is``), not equality.

    Args:
        module: The root ``nn.Module`` to search.
        op: The submodule object to locate.

    Returns:
        The first name yielded by ``module.named_modules()`` whose module is
        *op* (``""`` when *op* is *module* itself).

    Raises:
        ValueError: If *op* is not found under *module*.
    """
    missing = object()
    qualified = next(
        (qual_name for qual_name, sub in module.named_modules() if sub is op),
        missing,
    )
    if qualified is missing:
        raise ValueError(f"Cannot find op {op} in module {module}")
    return qualified


def append_str_prefix(x, prefix):
    """Prepend *prefix* to every string inside a (possibly nested) structure.

    Strings get ``prefix + x``. Tuples and lists are processed recursively
    element by element; tuples come back as plain ``tuple`` and lists as
    ``list``. Any other value is returned unchanged.

    Args:
        x: A string, a tuple/list (arbitrarily nested), or any other value.
        prefix: String to prepend.

    Returns:
        The transformed value.
    """
    if isinstance(x, str):
        return prefix + x
    if isinstance(x, (tuple, list)):
        mapped = [append_str_prefix(item, prefix) for item in x]
        return tuple(mapped) if isinstance(x, tuple) else mapped
    return x

Casper's avatar
Casper committed
48

49
50
51
52
53
54
55
56
def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
    """Drop layers whose name contains any of the skip keys.

    Matching is by substring: a layer is excluded when any entry of
    *modules_to_not_convert* occurs anywhere in its name.

    Args:
        linear_layers: Mapping of layer name -> layer.
        modules_to_not_convert: Iterable of substrings to skip, or ``None``.

    Returns:
        *linear_layers* itself (the same object) when
        *modules_to_not_convert* is ``None``; otherwise a new dict holding
        only the entries whose names match none of the skip keys, in the
        original order.
    """
    if modules_to_not_convert is None:
        return linear_layers
    return {
        layer_name: layer
        for layer_name, layer in linear_layers.items()
        if all(skip not in layer_name for skip in modules_to_not_convert)
    }