__init__.py 1.2 KB
Newer Older
J-shang's avatar
J-shang committed
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

J-shang's avatar
J-shang committed
4
5
from .counter import count_flops_params
from .mask_conflict import ChannelMaskConflict, GroupMaskConflict
6
from .utils import *
J-shang's avatar
J-shang committed
7
from .sensitivity_analysis import SensitivityAnalysis
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from .shape_dependency import *

def not_safe_to_prune(model, dummy_input):
    """
    Find the layers that may be unsafe to prune because pruning them can
    cause a shape conflict.

    A layer is considered risky when its output tensor flows directly into a
    shape-dependent function such as ``view`` or ``reshape``. Pruning such a
    layer changes its output shape, which can break the downstream
    shape-dependent call. If inference after speedup fails with a
    shape-related error, exclude the layers returned here from pruning and
    try again.

    Parameters
    ----------
    model: torch.nn.Module
        The target model to prune.
    dummy_input: torch.Tensor/list of torch.Tensor/tuple of Tensor
        Example input handed to ``ReshapeDependency`` together with
        ``model`` (presumably used to trace the model's graph -- confirm
        against ``ReshapeDependency``).

    Returns
    -------
    The ``dependency_sets`` attribute of a ``ReshapeDependency`` built from
    ``model`` and ``dummy_input``. NOTE(review): presumably a collection of
    layer names directly followed by shape-dependent functions; verify the
    exact container type in ``shape_dependency``.
    """
    # The dependency analysis is done entirely by ReshapeDependency; this
    # function only exposes its result under a task-oriented name.
    return ReshapeDependency(model, dummy_input).dependency_sets