# inject_nn.py
import inspect

import torch.nn as nn

5
from nni.retiarii import basic_unit
6

7
8
9
10
11
# Every public ``torch.nn`` attribute that is a class deriving from
# ``nn.Module``, except the base class itself and the container modules
# (``ModuleList``, ``ModuleDict``, ``Sequential``), which are left untouched.
# Order follows ``dir(nn)``, i.e. alphabetical.
_trace_module_names = [
    attr_name
    for attr_name in dir(nn)
    if attr_name not in ('Module', 'ModuleList', 'ModuleDict', 'Sequential')
    # Single-element tuple binds the attribute once without leaking a name
    # into module scope; it is only looked up after the exclusion check.
    for candidate in (getattr(nn, attr_name),)
    if inspect.isclass(candidate) and issubclass(candidate, nn.Module)
]
12
13
14


def remove_inject_pytorch_nn():
    """Restore the ``torch.nn`` classes replaced by ``inject_pytorch_nn``.

    For each name in ``_trace_module_names``, the object currently bound on
    ``torch.nn`` is inspected. If it exposes ``__wrapped__``, that attribute
    is rebound on ``torch.nn`` under the same name. Names whose current
    object has no ``__wrapped__`` are skipped. Only one ``__wrapped__`` level
    is removed per name.
    """
    for attr_name in _trace_module_names:
        current = getattr(nn, attr_name)
        try:
            original = current.__wrapped__
        except AttributeError:
            # Not a wrapped class: nothing to put back for this name.
            continue
        setattr(nn, attr_name, original)
18
19
20


def inject_pytorch_nn():
    """Replace each ``torch.nn`` class in ``_trace_module_names`` with its
    ``basic_unit``-wrapped version.

    The replacement is done in place with ``setattr`` on the ``torch.nn``
    module. Only code that looks the classes up through ``nn`` after this
    call sees the wrapped versions.

    Classes that already carry a ``__wrapped__`` attribute are skipped. This
    makes a repeated call a no-op instead of nesting a second wrapper.
    ``remove_inject_pytorch_nn`` restores only one ``__wrapped__`` level per
    name. Without the skip, two calls would leave ``torch.nn`` patched even
    after removal.

    NOTE(review): this assumes ``basic_unit`` exposes the original class via
    ``__wrapped__``. ``remove_inject_pytorch_nn`` makes the same assumption.
    """
    for name in _trace_module_names:
        cls = getattr(nn, name)
        if hasattr(cls, '__wrapped__'):
            # Already replaced (presumably by an earlier call); wrapping again
            # would stack wrappers that a single removal cannot unwind.
            continue
        setattr(nn, name, basic_unit(cls))