"lightx2v/vscode:/vscode.git/clone" did not exist on "08a3181b37ae530fedb2f4f54775d62cd173376c"
collect.py 1.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn


def collect_target_weights(model: nn.Module,
                           target_module_types: type,
                           skip_modules: list = None) -> dict:
    """Collect the weight tensors of the target modules in the model.

    Args:
        model (nn.Module): Model containing the target modules.
        target_module_types (type): Target module type, e.g., nn.Linear.
            It is passed straight to ``isinstance``, so a tuple of types
            is accepted as well.
        skip_modules (list, optional): Names of modules (as yielded by
            ``model.named_modules()``) that should not be included in
            the result. Defaults to None, which skips nothing.

    Returns:
        dict: A mapping from module name to that module's ``weight``.

    Raises:
        AssertionError: If a selected module has no ``weight`` attribute.
    """
    # A missing skip list means "skip nothing"; an empty tuple keeps the
    # membership test below valid without a mutable default argument.
    skipped = () if skip_modules is None else skip_modules

    target_weights = {}
    for name, module in model.named_modules():
        if not isinstance(module, target_module_types) or name in skipped:
            continue
        # Name the offending module so a wrong `target_module_types` is easy
        # to diagnose.
        assert hasattr(module, 'weight'), (
            f'Module {name!r} of type {type(module).__name__} '
            'has no attribute `weight`')
        target_weights[name] = module.weight

    return target_weights


def collect_target_modules(model: nn.Module,
                           target_module_types: type,
                           skip_modules: list = None) -> dict:
    """Collect the target modules in the model and return them in a
    dictionary.

    Args:
        model (nn.Module): Model containing the target modules.
        target_module_types (type): Target module type, e.g., nn.Linear.
            It is passed straight to ``isinstance``, so a tuple of types
            is accepted as well.
        skip_modules (list, optional): Names of modules (as yielded by
            ``model.named_modules()``) that should not be included in
            the result. Defaults to None, which skips nothing.

    Returns:
        dict: A mapping from module name to the matching module.
    """
    # Use None instead of a mutable `[]` default; an empty tuple keeps the
    # membership test below valid when nothing is skipped.
    skipped = () if skip_modules is None else skip_modules

    return {
        name: module
        for name, module in model.named_modules()
        if isinstance(module, target_module_types) and name not in skipped
    }