"...git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "8fb8f8b3cef0757401c1a7ed0fb1c8e3f659c5c4"
Unverified Commit 0a8fbbed authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

Loosen the group dependency (#4128)

parent d204d8bf
......@@ -44,9 +44,8 @@ class DependencyAwarePruner(Pruner):
self._unwrap_model()
self.graph = TorchModuleGraph(model, dummy_input)
self._wrap_model()
self.channel_depen = ChannelDependency(
traced_model=self.graph.trace)
self.group_depen = GroupDependency(traced_model=self.graph.trace)
self.channel_depen = ChannelDependency(model, dummy_input, traced_model=self.graph.trace)
self.group_depen = GroupDependency(model, dummy_input, traced_model=self.graph.trace)
self.channel_depen = self.channel_depen.dependency_sets
self.channel_depen = {
name: sets for sets in self.channel_depen for name in sets}
......
from .utils import *
\ No newline at end of file
from .utils import *
from .shape_dependency import *
from .shape_dependency import ReshapeDependency
def not_safe_to_prune(model, dummy_input):
    """
    Find the layers whose pruning may cause shape conflicts.

    A conv layer whose output tensor feeds directly into a
    shape-dependent operation (``view``, ``reshape``, etc.) cannot be
    pruned safely: changing its output shape can break the following
    shape-dependent call. This helper returns all such layers so that
    they can be excluded from pruning if a shape-related error appears
    after running inference on the sped-up model.

    Parameters
    ----------
    model: torch.nn.Module
        The target model to prune.
    dummy_input: torch.Tensor/list of torch.Tensor/tuple of Tensor
    """
    return ReshapeDependency(model, dummy_input).dependency_sets
......@@ -10,7 +10,7 @@ from .utils import get_module_by_name
_logger = logging.getLogger('FixMaskConflict')
def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
def fix_mask_conflict(masks, model, dummy_input, traced=None):
"""
MaskConflict fix the mask conflict for the channel dependencies
and group dependency.
......@@ -81,7 +81,7 @@ class MaskFix:
class GroupMaskConflict(MaskFix):
def __init__(self, masks, model=None, dummy_input=None, traced=None):
def __init__(self, masks, model, dummy_input, traced=None):
"""
GroupMaskConflict fix the mask conflict between the layers that
has group dependecy with each other.
......@@ -168,7 +168,7 @@ class GroupMaskConflict(MaskFix):
class ChannelMaskConflict(MaskFix):
def __init__(self, masks, model=None, dummy_input=None, traced=None):
def __init__(self, masks, model, dummy_input, traced=None):
"""
ChannelMaskConflict fix the mask conflict between the layers that
has channel dependecy with each other.
......
......@@ -3,10 +3,14 @@
import csv
import logging
import torch
import numpy as np
from nni.compression.pytorch.compressor import PrunerModuleWrapper
from .utils import get_module_by_name
__all__ = ['ChannelDependency', 'GroupDependency', 'InputChannelDependency', 'AttentionWeightDependency']
__all__ = ['ChannelDependency', 'GroupDependency',
'InputChannelDependency', 'AttentionWeightDependency']
CONV_TYPE = 'aten::_convolution'
......@@ -45,6 +49,7 @@ class Dependency:
# the model or a already traced model
assert model is not None and dummy_input is not None
self.graph = TorchModuleGraph(model, dummy_input, traced_model)
self.model = model
self.dependency = dict()
self.build_dependency()
......@@ -85,7 +90,7 @@ def reshape_break_channel_dependency(op_node):
class ChannelDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None, prune_type='Filter'):
def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
"""
This model analyze the channel dependencies between the conv
layers in a model.
......@@ -261,7 +266,7 @@ class InputChannelDependency(ChannelDependency):
If not, the input channel dependency will be passed to the following nodes.
"""
def __init__(self, model, dummy_input=None, traced_model=None):
def __init__(self, model, dummy_input, traced_model=None):
"""
This model analyze the input channel dependencies between the conv
layers in a model.
......@@ -323,7 +328,7 @@ class InputChannelDependency(ChannelDependency):
class GroupDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
def __init__(self, model, dummy_input, traced_model=None):
"""
This model analyze the group dependencis between the conv
layers in a model.
......@@ -383,13 +388,17 @@ class GroupDependency(Dependency):
group : int
the number of the groups of the target conv layer.
"""
cpp_conv = list(filter(lambda x: x.kind() ==
CONV_TYPE, node_group.node_cpps))
assert len(cpp_conv) == 1
cpp_conv = cpp_conv[0]
inputs = list(cpp_conv.inputs())
# get the number of the group from the input parameters
group = inputs[8].toIValue()
node_name = node_group.name
_, leaf_module = get_module_by_name(self.model, node_name)
if isinstance(leaf_module, PrunerModuleWrapper):
leaf_module = leaf_module.module
assert isinstance(
leaf_module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d))
group = leaf_module.groups
n_filter = leaf_module.out_channels
if n_filter == group:
# depthwise conv will not introduce extra group dependency
return 1
return group
def build_dependency(self):
......@@ -712,4 +721,3 @@ class AttentionWeightDependency(Dependency):
group = self.dependency[name]
if len(group) > 0:
csv_w.writerow([name, group])
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from .shape_dependency import ReshapeDependency
torch_float_dtype = [torch.float, torch.float16, torch.float32, torch.float64, torch.half, torch.double]
torch_integer_dtype = [torch.uint8, torch.int16, torch.short, torch.int32, torch.long, torch.bool]
......@@ -67,23 +67,3 @@ def randomize_tensor(tensor, start=1, end=100):
# with nn.init.uniform_
torch.nn.init.uniform_(tensor.data, start, end)
def not_safe_to_prune(model, dummy_input):
    """
    Return the layers that are unsafe to prune due to shape conflicts.

    If a conv layer's output is consumed directly by a shape-dependent
    function (such as ``reshape``/``view``), pruning that conv layer may
    change its output shape and trigger shape errors. This function
    collects every layer that is directly followed by such a
    shape-dependent function; if inference after speedup raises a
    shape-related error, exclude the returned layers and retry.

    Parameters
    ----------
    model: torch.nn.Module
        The target model to prune.
    dummy_input: torch.Tensor/list of torch.Tensor/tuple of Tensor
    """
    reshape_dependency = ReshapeDependency(model, dummy_input)
    return reshape_dependency.dependency_sets
\ No newline at end of file
......@@ -174,10 +174,18 @@ def prune_model_l1(model):
def generate_random_sparsity(model):
    """
    Build a random per-layer pruning config for every Conv2d in ``model``.

    Each Conv2d layer is assigned a sparsity drawn uniformly from
    ``[_start, _end)``. MobileNetV2 models get a much lower range
    because their pruning propagates aggressively and a high ratio
    could prune an entire layer away.

    Parameters
    ----------
    model : torch.nn.Module
        The model to generate a pruning config list for.

    Returns
    -------
    list of dict
        NNI-style config list with one entry per Conv2d layer.
    """
    _start = 0.5
    _end = 0.99
    if isinstance(model, models.mobilenet.MobileNetV2):
        # mobilenet models have great propagation characteristics
        # so we use smaller sparsity ratio to avoid pruning the whole
        # layer out
        _start = 0.01
        _end = 0.3
    cfg_list = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # NOTE: removed a leftover duplicate draw
            # (`np.random.uniform(0.5, 0.99)`) that ignored the
            # model-specific range and was immediately overwritten.
            sparsity = np.random.uniform(_start, _end)
            cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
                             'sparsity': sparsity})
    return cfg_list
......@@ -187,11 +195,19 @@ def generate_random_sparsity_v2(model):
"""
Only select 50% layers to prune.
"""
_start = 0.5
_end = 0.99
if isinstance(model, models.mobilenet.MobileNetV2):
# mobilenet models have great propagation characteristics
# so we use smaller sparsity ratio to avoid pruning the whole
# layer out
_start = 0.01
_end = 0.3
cfg_list = []
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
if np.random.uniform(0, 1.0) > 0.5:
sparsity = np.random.uniform(0.5, 0.99)
sparsity = np.random.uniform(_start, _end)
cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
'sparsity': sparsity})
return cfg_list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment