Unverified commit 607d6a91 authored by Ningxin Zheng, committed by GitHub

Error code for Speedup Module (#4173)

parent 8b61e774
@@ -4,6 +4,7 @@
import logging
import torch
import torch.nn as nn
from .error_code import EmptyLayerError, ShapeMisMatchError, InputsNumberError, OutputTypeError, UnBalancedGroupError
_logger = logging.getLogger(__name__)
@@ -44,7 +45,6 @@ replace_module = {
}
def convert_to_coarse_mask(t_mask, dim):
"""
Convert the mask tensor to the coarse-grained mask tensor.
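The body of convert_to_coarse_mask is elided from this diff, but every replace_* function below leans on its contract. A minimal sketch of the assumed behavior, inferred from how its two return values are used (index tensors of the fully pruned and of the remaining channels along `dim`); the actual implementation in this file may differ:

import torch

def convert_to_coarse_mask_sketch(t_mask, dim):
    # Sum the fine-grained mask over every dimension except `dim`;
    # a channel whose mask sums to zero is treated as fully pruned.
    sum_dims = [i for i in range(t_mask.dim()) if i != dim]
    t_sum = t_mask.sum(dim=sum_dims) if sum_dims else t_mask
    pruned = torch.nonzero(t_sum == 0, as_tuple=False).squeeze(-1)
    remained = torch.nonzero(t_sum != 0, as_tuple=False).squeeze(-1)
    return pruned, remained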
@@ -87,6 +87,7 @@ def no_replace(module, masks):
_logger.debug("no need to replace")
return module
def replace_prelu(prelu, masks):
"""
Parameters
@@ -102,8 +103,11 @@ def replace_prelu(prelu, masks):
The new prelu module
"""
in_masks, output_mask, weight_mask = masks
assert len(in_masks) == 1
assert isinstance(output_mask, torch.Tensor)
if len(in_masks) != 1:
raise InputsNumberError()
if not isinstance(output_mask, torch.Tensor):
raise OutputTypeError(type(output_mask), torch.Tensor)
in_mask = in_masks[0]
weight_mask = weight_mask['weight']
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
@@ -112,13 +116,17 @@ def replace_prelu(prelu, masks):
n_remained_out = weight_mask.size(0) - pruned_out.size(0)
remained_in, remained_out = remained_in.to(
prelu.weight.device), remained_out.to(prelu.weight.device)
assert n_remained_in == n_remained_out
if n_remained_in != n_remained_out:
raise ShapeMisMatchError()
if n_remained_in == 0:
return torch.nn.Identity()
new_prelu = torch.nn.PReLU(n_remained_in)
new_prelu.weight.data = torch.index_select(prelu.weight.data, 0, remained_in)
new_prelu.weight.data = torch.index_select(
prelu.weight.data, 0, remained_in)
return new_prelu
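A hedged usage sketch of the new checks in replace_prelu. The masks tuple is (input masks, output mask, weight-mask dict); all shapes below are hypothetical, and the expected result assumes the elided lines mirror the input-side handling:

import torch

prelu = torch.nn.PReLU(4)
# channel mask over an assumed (N, C, H, W) activation: keep channels 0 and 1
in_mask = torch.ones(1, 4, 8, 8)
in_mask[:, 2:] = 0
out_mask = in_mask.clone()
weight_mask = {'weight': torch.tensor([1., 1., 0., 0.])}

new_prelu = replace_prelu(prelu, ([in_mask], out_mask, weight_mask))
assert new_prelu.num_parameters == 2  # two channels survive
# two input masks now raise InputsNumberError instead of tripping a bare assert
replace_prelu(prelu, ([in_mask, in_mask], out_mask, weight_mask))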
def replace_linear(linear, masks):
"""
This function will replace the original linear according to
@@ -142,8 +150,11 @@ def replace_linear(linear, masks):
"""
in_masks, output_mask, weight_mask = masks
assert isinstance(linear, nn.Linear)
assert len(in_masks) == 1
assert isinstance(output_mask, torch.Tensor)
if len(in_masks) != 1:
raise InputsNumberError()
if not isinstance(output_mask, torch.Tensor):
raise OutputTypeError(type(output_mask), torch.Tensor)
in_mask = in_masks[0]
weight_mask = weight_mask['weight']
@@ -199,7 +210,8 @@ def replace_batchnorm1d(norm, masks):
# N, C, H, W
_, remained_in = convert_to_coarse_mask(in_mask, 1)
_, remained_out = convert_to_coarse_mask(output_mask, 1)
assert remained_in.size(0) == remained_out.size(0)
if remained_in.size(0) != remained_out.size(0):
raise ShapeMisMatchError()
num_features = remained_in.size(0)
_logger.info("replace batchnorm1d with num_features: %d", num_features)
@@ -241,7 +253,8 @@ def replace_batchnorm2d(norm, masks):
# N, C, H, W
_, remained_in = convert_to_coarse_mask(in_mask, 1)
_, remained_out = convert_to_coarse_mask(output_mask, 1)
assert remained_in.size(0) == remained_out.size(0)
if remained_in.size(0) != remained_out.size(0):
raise ShapeMisMatchError()
num_features = remained_in.size(0)
_logger.info("replace batchnorm2d with num_features: %d", num_features)
@@ -261,7 +274,6 @@ def replace_batchnorm2d(norm, masks):
return new_norm
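The construction of new_norm is elided above; presumably it follows the same pattern as the other replace_* functions, building a smaller module and gathering the surviving channels with index_select. A sketch under that assumption (the helper name is hypothetical):

import torch

def build_pruned_batchnorm2d_sketch(norm, remained, num_features):
    # copy the hyperparameters, then keep only the surviving channels
    new_norm = torch.nn.BatchNorm2d(num_features,
                                    eps=norm.eps,
                                    momentum=norm.momentum,
                                    affine=norm.affine,
                                    track_running_stats=norm.track_running_stats)
    if norm.affine:
        new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained)
        new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained)
    if norm.track_running_stats:
        new_norm.running_mean.data = torch.index_select(norm.running_mean, 0, remained)
        new_norm.running_var.data = torch.index_select(norm.running_var, 0, remained)
    return new_norm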
def replace_conv2d(conv, masks):
"""
Replace the original conv with a new one according to the inferred
@@ -285,7 +297,8 @@ def replace_conv2d(conv, masks):
in_masks, output_mask, weight_masks = masks
assert isinstance(conv, nn.Conv2d)
# the conv layer should only have one input tensor
assert len(in_masks) == 1
if len(in_masks) != 1:
raise InputsNumberError()
in_mask = in_masks[0]
@@ -296,8 +309,8 @@ def replace_conv2d(conv, masks):
n_remained_in = weight_mask.size(1) * conv.groups - pruned_in.size(0)
n_remained_out = weight_mask.size(0) - pruned_out.size(0)
assert n_remained_in == remained_in.size(0)
assert n_remained_out == remained_out.size(0)
if n_remained_in != remained_in.size(0) or n_remained_out != remained_out.size(0):
raise ShapeMisMatchError()
k_size1, k_size2 = conv.kernel_size
# Note: We should resolve the group dependency of the conv layers before
@@ -331,9 +344,10 @@ def replace_conv2d(conv, masks):
tmp_weight = torch.ones(
n_remained_out, new_inchannel_step, k_size1, k_size2)
tmp_weight = tmp_weight.to(conv.weight.device)
assert n_remained_in % new_inchannel_step == 0
assert n_remained_out % new_outchannel_step == 0
if new_inchannel_step == 0 or new_outchannel_step == 0:
raise EmptyLayerError()
if n_remained_in % new_inchannel_step != 0 or n_remained_out % new_outchannel_step != 0:
raise UnBalancedGroupError()
new_groups = 0
for groupid in range(conv.groups):
@@ -352,8 +366,9 @@ def replace_conv2d(conv, masks):
assert len(current_output_index) == 0
continue
# check if the number of remained channels in each group is the same
assert len(current_input_index) == new_inchannel_step
assert len(current_output_index) == new_outchannel_step
if len(current_input_index) != new_inchannel_step or len(current_output_index) != new_outchannel_step:
raise UnBalancedGroupError()
# copy the weight into tmp_weight
new_out_start = new_outchannel_step * new_groups
new_out_end = new_out_start + new_outchannel_step
@@ -386,7 +401,6 @@ def replace_conv2d(conv, masks):
new_conv.bias.data.copy_(torch.index_select(
conv.bias.data, 0, remained_out))
return new_conv
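The two new checks in the grouped path encode distinct failure modes, and the same pattern repeats in replace_convtranspose2d below. An illustrative sketch with hypothetical numbers: new_inchannel_step is the per-group channel quota after pruning, a quota of zero means the layer was pruned empty, and groups that keep unequal channel counts cannot be packed back into a uniform grouped conv:

n_remained_in, groups = 4, 2
new_inchannel_step = n_remained_in // groups   # each group should keep 2 channels
if new_inchannel_step == 0:
    raise EmptyLayerError()                    # fewer remaining channels than groups

kept_per_group = [3, 1]                        # hypothetical survivors per group
for kept in kept_per_group:
    if kept != new_inchannel_step:
        raise UnBalancedGroupError()           # groups must stay balanced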
@@ -410,7 +424,8 @@ def replace_convtranspose2d(convtrans, masks):
"""
in_masks, output_mask, weight_masks = masks
assert isinstance(convtrans, torch.nn.ConvTranspose2d)
assert len(in_masks) == 1
if len(in_masks) != 1:
raise InputsNumberError()
in_mask = in_masks[0]
weight_mask = weight_masks['weight']
@@ -420,8 +435,9 @@ def replace_convtranspose2d(convtrans, masks):
n_remained_in = weight_mask.size(0) - pruned_in.size(0)
n_remained_out = weight_mask.size(
1) * convtrans.groups - pruned_out.size(0)
assert n_remained_in == remained_in.size(0)
assert n_remained_out == remained_out.size(0)
if n_remained_in != remained_in.size(0) or n_remained_out != remained_out.size(0):
raise ShapeMisMatchError()
k_size1, k_size2 = convtrans.kernel_size
# Note: we should resolve the group dependency of the convtrans layers before
# running into this function
@@ -448,8 +464,10 @@ def replace_convtranspose2d(convtrans, masks):
n_remained_in, new_outchannel_step, k_size1, k_size2)
tmp_weight = tmp_weight.to(convtrans.weight.device)
assert n_remained_in % new_inchannel_step == 0
assert n_remained_out % new_outchannel_step == 0
if new_inchannel_step == 0 or new_outchannel_step == 0:
raise EmptyLayerError()
if n_remained_in % new_inchannel_step != 0 or n_remained_out % new_outchannel_step != 0:
raise UnBalancedGroupError()
new_groups = 0
for groupid in range(convtrans.groups):
@@ -471,8 +489,9 @@ def replace_convtranspose2d(convtrans, masks):
assert len(current_output_index) == 0
continue
# check if the number of remained channels in each group is the same
assert len(current_input_index) == new_inchannel_step
assert len(current_output_index) == new_outchannel_step
if len(current_input_index) != new_inchannel_step or len(current_output_index) != new_outchannel_step:
raise UnBalancedGroupError()
# copy the weight into tmp_weight
new_in_start = new_inchannel_step * new_groups
new_in_end = new_in_start + new_inchannel_step
@@ -505,7 +524,8 @@ def replace_convtranspose2d(convtrans, masks):
def replace_layernorm(layernorm, masks):
in_masks, _, _ = masks
assert isinstance(layernorm, nn.LayerNorm)
assert len(in_masks) == 1
if len(in_masks) != 1:
raise InputsNumberError()
in_mask = in_masks[0]
dim_n = len(in_mask.size())
new_shape = []
@@ -15,6 +15,7 @@ from .infer_mask import AutoMaskInference
from .jit_translate import jit_to_python_function
from ..utils import rand_like_with_shape
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# Error codes of the speedup module


class SpeedupError(Exception):
    def __init__(self, msg):
        self.msg = msg

    def __str__(self):
        return str(self.msg)


class EmptyLayerError(SpeedupError):
    def __init__(self):
        super(EmptyLayerError, self).__init__("Pruning a layer to empty is not legal")


class ShapeMisMatchError(SpeedupError):
    def __init__(self):
        super(ShapeMisMatchError, self).__init__("Shape mismatch!")


class InputsNumberError(SpeedupError):
    def __init__(self):
        super(InputsNumberError, self).__init__("The number of inputs of the target OP is wrong")


class OutputTypeError(SpeedupError):
    def __init__(self, current_type, target_type):
        msg = f"The output type should be {str(target_type)}, but {str(current_type)} was found"
        super(OutputTypeError, self).__init__(msg)


class UnBalancedGroupError(SpeedupError):
    def __init__(self):
        msg = "The number of remaining filters in each group is different"
        super(UnBalancedGroupError, self).__init__(msg)
\ No newline at end of file
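Because every specific code subclasses SpeedupError, callers can catch the base class and report any speedup failure uniformly. A hedged usage sketch: the ModelSpeedup call follows NNI's documented speedup API, while the model, the mask path, and the error_code import path are assumptions based on this commit:

import torch
from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.speedup.error_code import SpeedupError  # assumed path

model = MyPrunedModel()                   # hypothetical model with masks applied
dummy_input = torch.rand(1, 3, 224, 224)
try:
    ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()
except SpeedupError as err:
    # EmptyLayerError, ShapeMisMatchError, InputsNumberError,
    # OutputTypeError and UnBalancedGroupError all derive from SpeedupError
    print(f'speedup failed: {err}')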