"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "bf07ea6baf7a61ca0979715a5e27b26b6272c7d2"
Unverified Commit 4cf68009 authored by Xing's avatar Xing Committed by GitHub
Browse files

Add Group Norm support for Pruning model (#5069)

parent 858daf9f
......@@ -6,6 +6,7 @@ import logging
import queue
import re
from collections import defaultdict
from typing import List, Dict
import torch
from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy
CLASSTYPE_KIND = 'ClassType'
......@@ -262,6 +263,7 @@ class TorchModuleGraph(TorchGraph):
def __init__(self, model=None, dummy_input=None, traced_model=None):
    """
    Build the module graph, either by tracing *model* with *dummy_input*
    or from an already traced *traced_model*.
    """
    super().__init__(model, dummy_input, traced_model)
    # Bare annotation for type checkers; the actual value is assigned
    # from _build_graph() below.
    self.name_to_node: Dict[str, NodePyOP]
    # NOTE(review): the roles of these two attributes are defined by the
    # rest of the class — presumably a node counter and the set of modules
    # that occur more than once in the trace; confirm against _build_graph.
    self.global_count = 0
    self.reused_module = set()
    graph_maps = self._build_graph()
    self.name_to_node, self.input_to_node, self.output_to_node = graph_maps
......@@ -802,7 +804,7 @@ class TorchModuleGraph(TorchGraph):
node_group.auxiliary = self._extract_cat_info(
node_group, cpp_node)
def find_predecessors(self, unique_name):
def find_predecessors(self, unique_name) -> List[str]:
"""
Find predecessor node of the given node
......@@ -825,7 +827,7 @@ class TorchModuleGraph(TorchGraph):
predecessors.append(node_py.unique_name)
return predecessors
def find_successors(self, unique_name):
def find_successors(self, unique_name) -> List[str]:
"""
Find successor nodes of the given node
......
......@@ -48,7 +48,8 @@ replace_module = {
'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks),
'Embedding': lambda module, masks: replace_embedding(module, masks),
'PixelShuffle': lambda module, masks: replace_pixelshuffle(module, masks),
'Flatten': lambda module, masks: no_replace(module, masks)
'Flatten': lambda module, masks: no_replace(module, masks),
'GroupNorm': lambda module, masks: replace_groupnorm(module, masks),
}
......@@ -310,6 +311,90 @@ def replace_batchnorm2d(norm, masks):
return new_norm
def replace_groupnorm(norm: nn.GroupNorm, masks):
    """
    Build a smaller GroupNorm that only keeps the unpruned channels.

    Parameters
    ----------
    norm : torch.nn.GroupNorm
        The group norm module to be replaced
    masks : Tuple of the input masks, output masks and weight masks
        Tuple of the masks, for example
        ([input_m1, input_m2], [output_m], {'weight':weight_m})

    Returns
    -------
    torch.nn.GroupNorm
        The new group norm module

    Raises
    ------
    ShapeMisMatchError
        If the remained input channels differ from the remained output
        channels (GroupNorm cannot change the channel count).
    UnBalancedGroupError
        If the remained channels are not spread evenly over the non-empty
        groups, so no equivalent smaller GroupNorm exists.
    """
    in_masks, output_mask, _ = masks
    assert isinstance(norm, nn.GroupNorm)
    in_mask = in_masks[0]
    # Input layout is (N, C, H, W); dim 1 is the channel dimension.
    _, remained_in = convert_to_coarse_mask(in_mask, 1)
    _, remained_out = convert_to_coarse_mask(output_mask, 1)
    assert len(remained_in.size()) == 1
    if remained_in.size(0) != remained_out.size(0):
        raise ShapeMisMatchError()

    ori_channel_step = norm.num_channels // norm.num_groups
    # Single scan over all groups: the first non-empty group fixes the
    # expected number of remained channels per group, and every other
    # non-empty group must match it. (The original implementation scanned
    # the groups twice and relied on the per-group count leaking out of the
    # first loop; when every group was fully pruned that leaked value was
    # silently 0 and the balance check could never fire.)
    new_channel_step = None
    new_groups = 0
    for groupid in range(norm.num_groups):
        in_start = groupid * ori_channel_step
        in_end = in_start + ori_channel_step
        num_item = torch.logical_and(
            in_start <= remained_in,
            remained_in < in_end,
        ).sum().item()
        if num_item == 0:
            # this group is fully pruned
            continue
        if new_channel_step is None:
            new_channel_step = num_item
        elif num_item != new_channel_step:
            # the number of remained channels of each group must be the same
            raise UnBalancedGroupError()
        new_groups += 1

    new_num_channels = remained_in.size(0)
    new_module = nn.GroupNorm(
        new_groups,
        new_num_channels,
        eps=norm.eps,
        affine=norm.affine,
    )
    if new_module.affine:
        # Keep only the affine parameters of the remained channels.
        new_module.weight.data = torch.index_select(
            norm.weight.data,
            0,
            remained_in,
        )
        new_module.bias.data = torch.index_select(
            norm.bias.data,
            0,
            remained_in,
        )
    return new_module
def replace_instancenorm2d(norm, masks):
"""
Parameters
......@@ -409,18 +494,20 @@ def replace_conv2d(conv, masks):
in_end = in_start + ori_inchannel_step
out_start = groupid * ori_outchannel_step
out_end = out_start + ori_outchannel_step
current_input_index = list(
filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
current_output_index = list(
filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
new_inchannel_step: int = torch.logical_and(
in_start <= remained_in,
remained_in < in_end
).sum().item()
new_outchannel_step: int = torch.logical_and(
out_start <= remained_out,
remained_out < out_end
).sum().item()
# remap the global index to the group index
if len(current_input_index) == 0:
if new_inchannel_step == 0:
# if the whole group are pruned
continue
else:
new_inchannel_step = len(current_input_index)
new_outchannel_step = len(current_output_index)
break
tmp_weight = torch.ones(
n_remained_out, new_inchannel_step, k_size1, k_size2)
......@@ -436,12 +523,15 @@ def replace_conv2d(conv, masks):
in_end = in_start + ori_inchannel_step
out_start = groupid * ori_outchannel_step
out_end = out_start + ori_outchannel_step
current_input_index = list(
filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
current_output_index = list(
filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
current_input_mask = torch.logical_and(in_start <= remained_in, remained_in < in_end)
current_input_index = remained_in[current_input_mask]
current_output_mask = torch.logical_and(out_start <= remained_out, remained_out < out_end)
current_output_index = remained_out[current_output_mask]
# remap the global index to the group index
current_input_index = [x-in_start for x in current_input_index]
current_input_index = current_input_index - in_start
if len(current_input_index) == 0:
# if the whole group are pruned
assert len(current_output_index) == 0
......
......@@ -82,6 +82,8 @@ class AutoMaskInference:
if output_mask is not None:
# assume the given output mask is right
self.output_mask = output_mask
elif isinstance(module, nn.GroupNorm):
self.output_mask = self.in_masks[0]
else:
if isinstance(self.output, torch.Tensor):
self.output_mask = torch.ones_like(self.output)
......
......@@ -49,7 +49,7 @@ class Dependency:
# user should provide model & dummy_input to trace
# the model, or an already traced model
assert model is not None and dummy_input is not None
self.graph = TorchModuleGraph(model, dummy_input, traced_model)
self.graph: TorchModuleGraph = TorchModuleGraph(model, dummy_input, traced_model)
self.model = model
self.dependency = dict()
self.build_dependency()
......@@ -123,6 +123,9 @@ class ChannelDependency(Dependency):
elif self.prune_type == 'Batchnorm':
self.target_types.append('BatchNorm2d')
from typing import Dict, Set
self.dependency: Dict[str, Set[str]]
super(ChannelDependency, self).__init__(
model, dummy_input, traced_model)
......@@ -351,7 +354,7 @@ class GroupDependency(Dependency):
----------
model : torch.nn.Module
The model to be analyzed.
data : torch.Tensor
dummy_input : torch.Tensor
The example input data to trace the network architecture.
traced_model : torch._C.Graph
if we already have the traced graph of the target model, we do not
......@@ -418,6 +421,29 @@ class GroupDependency(Dependency):
return 1
return group
def _get_group_norm_condition(self, node_group) -> int:
    """
    Get the number of groups for a group norm layer.

    Parameters
    ----------
    node_group : NodePyGroup
        target node.

    Returns
    -------
    condition: int
        the number that layer's num channel
        require to be divisible to
    """
    _, module = get_module_by_name(self.model, node_group.name)
    # A pruner may have wrapped the layer; unwrap to reach the real module.
    if isinstance(module, (PrunerModuleWrapper, PrunerModuleWrapper_v2)):
        module = module.module
    assert isinstance(module, (torch.nn.GroupNorm))
    return module.num_groups
def build_dependency(self):
"""
Build the channel dependency for the conv layers
......@@ -441,8 +467,11 @@ class GroupDependency(Dependency):
"""
self.groups = {}
for node in self.graph.nodes_py.nodes_op:
if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
group = self._get_conv_groups(node)
if node.op_type in ['Conv2d', 'ConvTranspose2d', "GroupNorm"]:
if node.op_type in ['Conv2d', 'ConvTranspose2d']:
group = self._get_conv_groups(node)
elif node.op_type == "GroupNorm":
group = self._get_group_norm_condition(node)
if node.name in self.groups:
# the conv layer whose group is larger than 1 will require that
# it's number of output channel to be divisible by the number of group.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment