"official/utils/logs/benchmark_uploader.py" did not exist on "932364b62091aade23a586abdae989290be7fe72"
collect.py 3.09 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple, Union

from mmengine.config.lazy import LazyAttr
from torch import Tensor, nn


def collect_target_modules(model: nn.Module,
                           target: Union[str, type],
                           skip_names: List[str] = [],
                           prefix: str = '') -> Dict[str, nn.Module]:
    """Collects the specific target modules from the model.

    Args:
        model : The PyTorch module from which to collect the target modules.
        target : The specific target to be collected. It can be a class of a
            module or the name of a module.
        skip_names : List of names of modules to be skipped during collection.
        prefix : A string to be added as a prefix to the module names.

    Returns:
        A dictionary mapping from module names to module instances.
    """

    # Targets specified through an mmengine lazy config attribute must be
    # materialized into the actual class first.
    if isinstance(target, LazyAttr):
        target = target.build()

    if not isinstance(target, (type, str)):
        raise TypeError('Target must be a string (name of the module) '
                        'or a type (class of the module)')

    def _is_target(n, m):
        # Match either by class name (when `target` is a str) or by
        # isinstance check (when `target` is a type); skip modules whose
        # (unprefixed) name is in `skip_names`.
        if isinstance(target, str):
            return target == type(m).__name__ and n not in skip_names
        return isinstance(m, target) and n not in skip_names

    name2mod = {}
    for name, mod in model.named_modules():
        m_name = f'{prefix}.{name}' if prefix else name
        if _is_target(name, mod):
            name2mod[m_name] = mod
    return name2mod
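
# A minimal usage sketch (the toy `net` below is hypothetical, purely for
# illustration; see the runnable demo at the bottom of this file):
#
#   net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
#   collect_target_modules(net, nn.Linear)  # by class -> {'0': ..., '2': ...}
#   collect_target_modules(net, 'ReLU')     # by name  -> {'1': ReLU()}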


def collect_target_weights(model: nn.Module, target: Union[str, type],
                           skip_names: List[str]) -> Dict[nn.Module, Tensor]:
    """Collects weights of the specific target modules from the model.

    Args:
        model : The PyTorch module from which to collect the weights of
            target modules.
        target : The specific target whose weights are to be collected. It
            can be a class of a module or the name of a module.
        skip_names : Names of modules to be skipped during weight collection.

    Returns:
        A dictionary mapping from module instances to their
            corresponding weights.
    """

    named_modules = collect_target_modules(model, target, skip_names)
    mod2weight = {}
    for mod in named_modules.values():
        assert hasattr(
            mod, 'weight'), "The module does not have a 'weight' attribute"
        mod2weight[mod] = mod.weight
    return mod2weight
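
# Usage sketch (continuing the hypothetical `net` above; `skip_names` has no
# default here, so pass an empty list to skip nothing):
#
#   collect_target_weights(net, nn.Linear, [])     # {Linear(...): weight, ...}
#   collect_target_weights(net, nn.Linear, ['0'])  # skip the first Linear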


def bimap_name_mod(
    name2mod_mappings: List[Dict[str, nn.Module]]
) -> Tuple[Dict[str, nn.Module], Dict[nn.Module, str]]:
    """Generates bidirectional maps from module names to module instances and
    vice versa.

    Args:
        name2mod_mappings : List of dictionaries each mapping from module
            names to module instances.

    Returns:
        Two dictionaries providing bidirectional mappings between module
            names and module instances.
    """

    name2mod = {}
    mod2name = {}
    for mapping in name2mod_mappings:
        mod2name.update({v: k for k, v in mapping.items()})
        name2mod.update(mapping)
    return name2mod, mod2name
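

if __name__ == '__main__':
    # Hedged smoke test, not part of the library API: a hypothetical toy
    # model exercising the three helpers above (requires only torch).
    net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    linears = collect_target_modules(net, nn.Linear, prefix='net')
    print(sorted(linears))  # ['net.0', 'net.2']

    weights = collect_target_weights(net, nn.Linear, [])
    print([tuple(w.shape) for w in weights.values()])  # [(8, 4), (2, 8)]

    name2mod, mod2name = bimap_name_mod([collect_target_modules(net, 'ReLU')])
    print(mod2name[name2mod['1']])  # '1'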