Unverified Commit 0aea0a56 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

Support Speedup for Slim Pruner. (#4008)

parent d8e56857
......@@ -184,14 +184,10 @@ class ChannelMaskConflict(MaskFix):
super(ChannelMaskConflict, self).__init__(
masks, model, dummy_input, traced)
self.conv_prune_dim = detect_mask_prune_dim(masks, model)
self.channel_prune_type = detect_channel_prune_type(masks, model)
_logger.info('Dectected conv prune dim" %d', self.conv_prune_dim)
def fix_mask(self):
"""
Fix the mask conflict before the mask inference for the layers that
has shape dependencies. This function should be called before the
mask inference of the 'speedup' module.
"""
"""
Fix the mask conflict before the mask inference for the layers that
has shape dependencies. This function should be called before the
......@@ -200,7 +196,8 @@ class ChannelMaskConflict(MaskFix):
"""
if self.conv_prune_dim == 0:
channel_depen = ChannelDependency(
self.model, self.dummy_input, self.traced)
self.model, self.dummy_input, self.traced, self.channel_prune_type)
else:
channel_depen = InputChannelDependency(
self.model, self.dummy_input, self.traced)
......@@ -307,10 +304,44 @@ class ChannelMaskConflict(MaskFix):
return self.masks
def detect_channel_prune_type(masks, model):
"""
User can prune a channel through two ways: 1) prune
the corresponding filter of the conv layer(all the
filter related pruner), 2) prune the BN layers that
followed after a conv(Slim pruner). This function find
the pruning type of the masks.
Parameters
----------
masks: dict
A dict object that stores the masks.
model: nn.Module
Model object which the mask can be applied on.
Returns:
-------
prune_type: str
Could be Filter or Batchnorm
"""
prune_type = 'Filter'
all_batch_norm = True
for layer_name in masks:
_, m = get_module_by_name(model, layer_name)
if m is None or (not isinstance(m, torch.nn.BatchNorm2d)):
all_batch_norm = False
break
if all_batch_norm:
# if all masks are for batchnorm layers, then the prune_type is BatchNorm
# Note, actually we currently do not support pruning both Conv and BatchNorm
# at the same time.
prune_type = 'Batchnorm'
return prune_type
def detect_mask_prune_dim(masks, model):
"""
Detect how the masks of convolutional layers are pruned.
Parameters
----------
masks: dict
......
......@@ -85,7 +85,7 @@ def reshape_break_channel_dependency(op_node):
class ChannelDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
def __init__(self, model=None, dummy_input=None, traced_model=None, prune_type='Filter'):
"""
This model analyze the channel dependencies between the conv
layers in a model.
......@@ -98,7 +98,18 @@ class ChannelDependency(Dependency):
traced_model : torch._C.Graph
if we alreay has the traced graph of the target model, we donnot
need to trace the model again.
"""
prune_type: str
This parameter indicates the channel pruning type: 1) `Filter`
prune the filter of the convolution layer to prune the corresponding
channels 2) `Batchnorm`: prune the channel in the batchnorm layer
"""
self.prune_type = prune_type
self.target_types = []
if self.prune_type == 'Filter':
self.target_types.extend(['Conv2d', 'Linear', 'ConvTranspose2d'])
elif self.prune_type == 'Batchnorm':
self.target_types.append('BatchNorm2d')
super(ChannelDependency, self).__init__(
model, dummy_input, traced_model)
......@@ -114,12 +125,13 @@ class ChannelDependency(Dependency):
parent_layers: list
nearest father conv/linear layers for the target worknode.
"""
parent_layers = []
queue = []
queue.append(node)
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
if curnode.op_type in self.target_types:
# find the first met conv
parent_layers.append(curnode.name)
continue
......@@ -130,6 +142,7 @@ class ChannelDependency(Dependency):
parents = [self.graph.name_to_node[name] for name in parents]
for parent in parents:
queue.append(parent)
return parent_layers
def build_dependency(self):
......@@ -193,7 +206,7 @@ class ChannelDependency(Dependency):
csv_w = csv.writer(csvf, delimiter=',')
csv_w.writerow(header)
for node in self.graph.nodes_py.nodes_op:
if node.op_type != 'Conv2d' or node in visited:
if node.op_type not in self.target_types or node in visited:
continue
setid += 1
row = ['Set %d' % setid]
......@@ -220,7 +233,7 @@ class ChannelDependency(Dependency):
d_sets = []
visited = set()
for node in self.graph.nodes_py.nodes_op:
if (node.op_type != 'Conv2d' and node.op_type != 'Linear') or node in visited:
if node.op_type not in self.target_types or node in visited:
continue
tmp_set = set()
if node.name not in self.dependency:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment