"...resnet50_tensorflow.git" did not exist on "50235dab34cd15f728367807fa776e06d2fcc1a4"
Unverified commit 7eedec46, authored by Ningxin Zheng, committed by GitHub

Model Speedup Refactor (#3462)

parent 5b99b598
@@ -140,9 +140,6 @@ Topology Utilities
.. autoclass:: nni.compression.pytorch.utils.shape_dependency.GroupDependency
   :members:
-.. autoclass:: nni.compression.pytorch.utils.mask_conflict.CatMaskPadding
-   :members:
.. autoclass:: nni.compression.pytorch.utils.mask_conflict.GroupMaskConflict
   :members:
......
@@ -71,7 +71,11 @@ class TorchGraph:
    def _trace(self, model, dummy_input):
        training = model.training
        model.eval()
-        self.trace = torch.jit.trace(model, dummy_input)
+        kw_args = {}
+        if torch.__version__ >= '1.6.0':
+            # only pytorch with version greater than 1.6.0 has the strict option
+            kw_args['strict'] = False
+        self.trace = torch.jit.trace(model, dummy_input, **kw_args)
        torch._C._jit_pass_inline(self.trace.graph)
        model.train(training)
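For readers unfamiliar with the strict option, a minimal standalone sketch of the version-gated tracing introduced here (the nn.Linear module and input shape are placeholders, not part of this PR):

import torch
import torch.nn as nn

model = nn.Linear(4, 2).eval()           # placeholder module for illustration
dummy_input = torch.randn(1, 4)
kw_args = {}
if torch.__version__ >= '1.6.0':
    # 'strict' only exists in PyTorch >= 1.6.0; strict=False lets the tracer
    # accept modules whose outputs are lists or dicts
    kw_args['strict'] = False
trace = torch.jit.trace(model, dummy_input, **kw_args)
torch._C._jit_pass_inline(trace.graph)   # inline submodule calls into one flat graph

Note that comparing torch.__version__ as a plain string misorders versions such as '1.10.0' against '1.6.0'; packaging.version.parse would be a more robust check.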
@@ -247,6 +251,7 @@ class TorchModuleGraph(TorchGraph):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        super().__init__(model, dummy_input, traced_model)
        self.global_count = 0
+        self.reused_module = set()
        self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
        self._extract_auxiliary_info()
@@ -390,9 +395,12 @@ class TorchModuleGraph(TorchGraph):
                outputs.append(output_name)
            else:
                outputs.append(output_name)
+        unique_outputs = list(set(outputs))
+        # remove the duplicated output names
+        unique_outputs.sort(key=outputs.index)
        nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
-                            node_group, inputs=list(inputs), outputs=list(outputs))
+                            node_group, inputs=list(inputs), outputs=unique_outputs)
        return nodepy

    def _extract_cat_info(self, node_group, cpp_node):
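The order-preserving de-duplication added in this hunk can be illustrated on its own (the output names below are made up):

outputs = ['out1', 'out2', 'out1', 'out3', 'out2']
unique_outputs = list(set(outputs))
unique_outputs.sort(key=outputs.index)      # keep first-appearance order
assert unique_outputs == ['out1', 'out2', 'out3']
# an equivalent linear-time alternative, relying on dicts preserving insertion order:
assert list(dict.fromkeys(outputs)) == ['out1', 'out2', 'out3']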
@@ -724,6 +732,8 @@ class TorchModuleGraph(TorchGraph):
                unique_name = module_name
                if use_count > 0:
                    unique_name = module_name + '.%d' % use_count
+                    self.reused_module.add(unique_name)
+                    self.reused_module.add(module_name)
                node_group = self._expand_module_node(
                    node, module_name, unique_name, module_to_type[module_name],
                    node_cpps, input_to_node, output_to_node, 'module')
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
from ..utils import randomize_tensor, torch_float_dtype, torch_integer_dtype
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
STD_DELTA = 1e-6
class AutoMaskInference:
def __init__(self, module, dummy_input, in_masks=None, weight_mask=None, \
output_mask=None, name=None, in_constants=None, state_dict=None, batch_dim=0):
"""
This class will infer the mask of the target module automatically.
This update_direct_sparsity will infer the output mask according
to the input masks; in contrast, update_indirect_sparsity will
infer the input masks according to given output masks. The newly
found sparsity will be incrementally updated to the original in_masks
and output_mask.
Parameters
----------
module: torch.nn.Module/function
The target module to infer the mask. Need to be callable.
dummy_input: torch.Tensor/list of Tensor
The dummy_input of the target module.
in_masks: list of torch.Tensor
The input masks of the target module, if in_masks is not None, then
update_direct_sparsity and update_indirect_sparsity will incrementally
update the given in_masks, else, AutoMaskInference will create a new
in_masks for the target module.
output_mask: torch.Tensor
The output mask of the target module. Similar to in_masks, if output_mask
is not None, then update_direct_sparsity and update_indirect_sparsity will
incrementally update the given output_mask, else AutoMaskInference will create
one output_mask for the target module.
weight_mask: dict of the weight masks
The weight masks of the target module, the key is the corresponding name of
the mask. For example: {'weight': torch.ones(1000, 1000), 'bias': torch.ones(1000)}
name: str
Name of the target module.
in_constants: list of torch.Tensor
The corresponding constant values of the in_masks.
state_dict: dict of torch.Tensor
The original values of the weights.
batch_dim: int
The index of the batch dimension of the input tensors.
"""
errmsg = '%s is not callable, should pass the nn.Module/function' % str(
module)
assert callable(module), errmsg
self.module = module
# Initialize the dummy_input
if isinstance(dummy_input, list):
# if there are multiple input variables
self.dummy_input = dummy_input
else:
# if there is only one input variable
self.dummy_input = [dummy_input]
# Initialize the masks for input tensors
self.in_masks = in_masks if in_masks is not None else [
None] * len(self.dummy_input)
self.in_constants = in_constants if in_constants is not None else [
torch.zeros_like(x) for x in self.dummy_input]
for in_id, _ in enumerate(self.in_masks):
if self.in_masks[in_id] is None and \
isinstance(self.dummy_input[in_id], torch.Tensor):
# if the input mask is None then create an all-ones mask for the corresponding input tensor
self.in_masks[in_id] = torch.ones_like(self.dummy_input[in_id])
# ones_like will put the created mask on the same device with the dummy_input
# Initialize the mask for output tensors
self.output = self.module(*self.dummy_input)
# self.output.requires_grad_()
if output_mask is not None:
# assume the given output mask is right
self.output_mask = output_mask
else:
if isinstance(self.output, torch.Tensor):
self.output_mask = torch.ones_like(self.output)
elif isinstance(self.output, list) or isinstance(self.output, tuple):
self.output_mask = []
for o_tensor in self.output:
if isinstance(o_tensor, torch.Tensor):
self.output_mask.append(torch.ones_like(o_tensor))
else:
# if one of the outputs is not tensor, set the corresponding
# mask to None
self.output_mask.append(None)
else:
self.output_mask = None
# Initialize the mask for the parameters
self.weights = {}
self.weight_mask = {}
if weight_mask:
self.weight_mask.update(weight_mask)
if isinstance(self.module, nn.Module):
# a plain function should not have parameters
# get all the parameter tensors of the target module
for para_name, para in module.named_parameters():
self.weights[para_name] = para
if para_name not in self.weight_mask:
self.weight_mask[para_name] = torch.ones_like(para.data)
self.name = name
self.state_dict = state_dict
# TODO support the other batch dimension in the future
self.batch_dim = batch_dim
def random_init(self, start=0.1, end=8.0):
"""
Randomly initialize the weights of the module. The values of
the tensor will not affect the mask auto-inference.
"""
# currently we set the random range to 0.1-8.0 because of ReLU6:
# if we use a range far larger than 6, it may infer a wrong mask
# when the confidence is low. In the future, we will add the mask inference
# rules for ReLU6 to break this range constraint.
with torch.no_grad():
for tensor in self.dummy_input:
if isinstance(tensor, torch.Tensor) and len(tensor.size()) > 0:
# if the tensor is a scalar, then skip this tensor
randomize_tensor(tensor, start, end)
for para in self.weights:
randomize_tensor(self.weights[para].data, start, end)
def zero_grad(self):
"""
Set the gradient of the weight, input tensor to be zeros.
"""
with torch.no_grad():
# set the weight's gradient to zero
if isinstance(self.module, nn.Module):
self.module.zero_grad()
# also zero the gradient of the input tensors
for tensor in self.dummy_input:
if isinstance(tensor, torch.Tensor):
if tensor.grad is not None:
tensor.grad.data.zero_()
def requires_grad_(self, flag=True):
"""
Set the requires_grad of input tensor and parameters to flag.
"""
for t_in in self.dummy_input:
if isinstance(t_in, torch.Tensor) and t_in.dtype in torch_float_dtype:
# only float type can require the gradient
# enable the auto gradient
t_in.requires_grad_(flag)
for para_name in self.weights:
if self.weights[para_name].dtype in torch_float_dtype:
self.weights[para_name].requires_grad_(flag)
def apply_mask(self):
self.__apply_input_mask()
self.__apply_weight_mask()
def __apply_input_mask(self):
"""
Apply the mask of the input tensor.
"""
with torch.no_grad():
# apply the input mask
for tid, in_tensor in enumerate(self.dummy_input):
if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
in_tensor.data = in_tensor.data * \
self.in_masks[tid] + \
(1-self.in_masks[tid]) * self.in_constants[tid]
def __apply_weight_mask(self):
"""
Apply the weight mask of this module.
"""
with torch.no_grad():
# apply the weight mask
for para in self.weights:
if para in self.weight_mask:
self.weights[para].data *= self.weight_mask[para].data
def isconstants(self, tout):
"""
Find the constants in the tensor tout. This function returns a mask tensor that
indicates whether each value in tout is a constant, plus a second tensor that
holds the values of those constants.
Parameters
----------
tout: torch.Tensor
The target output tensor in which to find the constants.
Returns
-------
mask: torch.Tensor
A mask tensor (same shape as tout) that indicates whether
the corresponding value is a constant.
constant: torch.Tensor
A tensor (same shape as tout) that holds the values of
the constants in tout.
"""
assert isinstance(tout, torch.Tensor)
out_mask = torch.ones_like(tout)
constant = torch.zeros_like(tout)
# judge if tout is a scalar(tensor that only have one value)
if len(tout.size()) == 0:
# tout is a scalar tensor, for the scalar tensor, we take
# this scalar as a constant, usually, the scalar tensor is returned
# by the size() function
constant = tout
return out_mask, constant
if tout.dtype in torch_integer_dtype:
# PyTorch cannot use torch.mean and torch.std to process
# integers, so if the dtype of the input tensor is an integer type, we need
# to check for constants by ourselves
# Note: the first dimension should be the batch dimension
same = tout[:] == tout[0]
reduced = torch.sum(same, dim=0)
is_constant = reduced == tout.size(0)
out_mask[:, is_constant] = 0
constant[:, is_constant] = tout[0][is_constant]
else:
# calculate the std of the output among batch dimension
std = torch.std(tout, dim=0)
# calculate the mean value of the output among the batch dimension
mean = torch.mean(tout, dim=0)
mask_pos = std < STD_DELTA
out_mask[:, mask_pos] = 0
constant[:, mask_pos] = mean[mask_pos]
return out_mask, constant
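A toy illustration of the constant detection above, outside the speedup machinery (the numbers are made up; dim 0 is the batch dimension):

import torch
STD_DELTA = 1e-6
tout = torch.tensor([[0.7, 0.0],
                     [1.3, 0.0],
                     [2.1, 0.0]])            # column 1 never changes across the batch
std = torch.std(tout, dim=0)
mean = torch.mean(tout, dim=0)
mask_pos = std < STD_DELTA                   # positions that are constant
out_mask = torch.ones_like(tout)
constant = torch.zeros_like(tout)
out_mask[:, mask_pos] = 0                    # column 1 is masked out
constant[:, mask_pos] = mean[mask_pos]       # and its constant value (0.0) is recorded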
def update_indirect_sparsity(self):
"""
This function updates the indirect sparsity. To explain what
indirect sparsity is: suppose there are two tensors TA and TB, and
we compute TC = TA x TB, where TC is also a tensor.
Once some values in TA are masked to zeros, the corresponding
positions in TB are also potential sparsity, because they have no
effect on the final output (the gradient at these positions in TB is
always 0). This function finds the potential sparsity caused
by other sparsity (we call it indirect sparsity here). Basically we can find
this potential sparsity through the gradient.
"""
# When we backward, each node only updates its own output mask.
# This is because some ops may involve broadcasting: for example, OP A's output
# tensor may be taken by two OPs (B, C) as input. So we cannot
# directly update the input mask at OP B or C. We can only
# update the mask of A's output tensor once B and C are
# already updated (their gradients are already calculated and added to
# A's output tensor).
# Besides, updating the mask of A's output tensor is equivalent to updating
# the input masks of OP B and C.
if isinstance(self.output, torch.Tensor) and self.output.grad is not None:
# if the output has a gradient, this node has successor
# nodes and those successors have already updated their indirect
# sparsity
# we can mask the values whose gradient is always zeros
gradient_sum = torch.sum(torch.abs(self.output.grad.data), dim=0)
_grad_zero = gradient_sum == 0
for batchid in range(self.output.size(0)):
# set the same mask value for the whole batch
self.output_mask[batchid][_grad_zero] = 0
elif isinstance(self.output, tuple) or isinstance(self.output, list):
assert isinstance(self.output_mask, (tuple, list))
for oid, tout in enumerate(self.output):
errmsg = 'The output only support tensor/list of tensors'
assert isinstance(tout, torch.Tensor), errmsg
gradient_sum = torch.sum(
torch.abs(tout.grad.data), dim=0)
_grad_zero = gradient_sum == 0
for batchid in range(tout.size(0)):
# set the same mask value for the whole batch
self.output_mask[oid][batchid][_grad_zero] = 0
self.requires_grad_(True)
# Forward inference with auto gradient enabled
# Note: tensors that need gradient cannot be used in the in-place operator
self.random_init()
self.apply_mask()
# Some operator may have the in_place operations, so we need to clone the input
# before passing to the self.module
tmp_dummy_input = [x.clone() if isinstance(
x, torch.Tensor) else x for x in self.dummy_input]
output = self.module(*tmp_dummy_input)
if output.grad_fn is None:
# the output does not have the gradient function
return
# Note: output maybe tensor or list/tuple of tensors
if isinstance(output, torch.Tensor):
output.backward(self.output_mask)
elif isinstance(output, list) or isinstance(output, tuple):
for tid, t_out in enumerate(output):
t_out.backward(self.output_mask[tid])
# update the sparsity of the parameters
for para_name in self.weights:
grad_zero = self.weights[para_name].grad.data == 0
self.weight_mask[para_name][grad_zero] = 0
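A toy illustration of the gradient-based rule above, with plain tensors (TC = TA * TB elementwise; positions of TB whose gradient is always zero are indirectly sparse):

import torch
TA = torch.tensor([[0.0, 2.0, 0.0]])                 # already masked to zeros at positions 0 and 2
TB = torch.tensor([[1.5, -0.3, 0.8]], requires_grad=True)
TC = TA * TB
TC.backward(torch.ones_like(TC))                     # propagate an all-ones output mask
grad_zero = torch.sum(torch.abs(TB.grad), dim=0) == 0
assert grad_zero.tolist() == [True, False, True]     # positions 0 and 2 of TB are indirectly sparse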
def update_direct_sparsity(self):
# we don't need the gradient in the forward inference
out_mask = None
constant = None
with torch.no_grad():
# Note: we need to randomly re-initialize the input one more time here,
# because some operations are in-place (such as relu_) and may have
# modified or written 0s into the dummy_input
self.random_init()
# apply the mask for the input tensor and the weight tensor
self.apply_mask()
# Note: due to in-place operators such as relu_,
# out may be the same tensor as dummy_input,
# so we use clone and detach to create a new tensor with
# the same values.
out = self.module(*self.dummy_input)
if isinstance(out, torch.Tensor):
out_mask, constant = self.isconstants(out.clone().detach())
elif isinstance(out, tuple) or isinstance(out, list):
out_mask = []
constant = []
for tout in out:
_mask, _constant = self.isconstants(tout.clone().detach())
out_mask.append(_mask)
constant.append(_constant)
else:
_logger.warning(
'Only support the OP whose output is tensor/tuple of tensor/list of tensor')
# We also need to randomize the parameters of the module, because if a weight of the model has
# an unmasked 0, then our output sparsity inference may be wrong.
# However, after randomizing the weights/parameters, the constants in the output tensors may
# differ from the constants calculated from the original state_dict. So, to get the right
# constants and eliminate the bias between the model before and after sparsity
# inference, we need to reload the state_dict and recalculate the constants.
# Currently we also get the constant values at the same time as inferring the mask; in
# the future, we will separate the constant inference into a single graph pass.
if len(self.weights) > 0 and self.state_dict is not None:
self.module.load_state_dict(self.state_dict)
# apply weight mask
self.__apply_weight_mask()
out = self.module(*self.dummy_input).clone().detach()
if isinstance(out, torch.Tensor):
constant = torch.zeros_like(out)
constant_pos = out_mask == 0
constant[constant_pos] = out[constant_pos]
elif isinstance(out, (list, tuple)):
constant = []
for i, tout in enumerate(out):
_tmp = torch.zeros_like(tout)
sparsity_pos = out_mask[i] == 0
_tmp[sparsity_pos] = tout[sparsity_pos]
constant.append(_tmp)
if isinstance(out_mask, torch.Tensor):
assert isinstance(self.output_mask, torch.Tensor)
self.output_mask *= out_mask
elif isinstance(out_mask, list):
for i, _ in enumerate(out_mask):
self.output_mask[i] *= out_mask[i]
else:
_logger.warning('There is no output sparsity')
# also save the out_constant
self.out_constant = constant
def get_masks(self):
return (self.in_masks, self.output_mask, self.weight_mask)
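A rough usage sketch of this class (the conv layer, shapes and masks are made up; in the speedup pass these objects are built from the traced graph):

import torch
import torch.nn as nn
conv = nn.Conv2d(2, 2, kernel_size=1)
dummy_input = torch.randn(4, 2, 8, 8)            # batch dimension is dim 0
in_mask = torch.ones_like(dummy_input)
in_mask[:, 1] = 0                                 # pretend channel 1 was pruned upstream
infer = AutoMaskInference(conv, [dummy_input], in_masks=[in_mask],
                          state_dict=conv.state_dict())
infer.update_direct_sparsity()                    # forward pass, detect constant outputs
in_masks, output_mask, weight_mask = infer.get_masks()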
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import re
import logging
from functools import partial
import torch
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def translate_list(list_node, speedup=None):
"""
Get the list of values from the list construct node.
Parameters
---------
list_node: Torch.C.Value
The cpp node of the target list.
speedup: ModelSpeedup
The model speedup object.
Returns
-------
values: list
The list of values in the target cpp list node.
"""
# the node that create the list
create_node = list_node.node()
assert create_node.kind() == 'prim::ListConstruct'
inputs = list(create_node.inputs())
values = []
for _i in inputs:
debugName = _i.debugName()
if speedup is not None and debugName in speedup.internal_result:
# this value is the result of other nodes, such as
# aten::size
values.append(speedup.internal_result[debugName].item())
else:
# if the corresponding value is a constant
values.append(_i.toIValue())
return values
def parse_constant(cvalue, speedup):
"""
Parse the constant values from this Node
Parameters
----------
cvalue: Torch.C.Value
The cpp node of the target constant value.
speedup: ModelSpeedup
The Model speedup module.
Returns
-------
value: int/float/tensor
The constant values parsed from the node.
"""
logger.debug('Try to parse the constant value: %s', cvalue.debugName())
if cvalue.toIValue() is not None:
return cvalue.toIValue()
if cvalue.debugName() in speedup.internal_result:
return speedup.internal_result[cvalue.debugName()]
# Get the operator node of the this value
op_node = cvalue.node()
inputs = op_node.inputs()
input_values = [parse_constant(_i, speedup) for _i in inputs]
func = trans_from_jit_to_python[op_node.kind()](op_node, speedup)
return func(*input_values)
def dropout_python(node, speedup):
return torch.dropout
def flatten_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
start_dim = inputs[1].toIValue()
end_dim = inputs[2].toIValue()
new_flatten = partial(torch.flatten, start_dim=start_dim, end_dim=end_dim)
return new_flatten
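Most translators above follow the same pattern: read the constant arguments from the traced node and bind them with functools.partial, leaving only the tensor input free. The same pattern in isolation, with hand-written constants instead of values parsed from the graph:

import torch
from functools import partial
new_flatten = partial(torch.flatten, start_dim=1, end_dim=-1)   # constants bound up front
x = torch.randn(2, 3, 4)
assert new_flatten(x).shape == (2, 12)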
def relu_inplace_python(node, speedup):
return torch.relu_
def relu_python(node, speedup):
return torch.relu
def sigmoid_python(node, speedup):
return torch.sigmoid
def mean_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim_list = translate_list(inputs[1], speedup)
keep_dim = inputs[2].toIValue()
new_mean = partial(torch.mean, dim=tuple(dim_list), keepdim=keep_dim)
return new_mean
def add_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
constant = None
for i in range(2):
input_i = inputs[i]
debug_name = input_i.debugName()
if debug_name not in speedup.internal_result:
# this input is a constant value
# TODO: what if this input is a constant tensor
if input_i.toIValue() is not None:
constant = parse_constant(input_i, speedup)
break
if constant is None:
return torch.add
else:
new_add = partial(torch.add, constant)
return new_add
def floor_div_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
divisor = inputs[1]
constant = None
if divisor.debugName() not in speedup.internal_result:
# divisor is a constant value/tensor
constant = parse_constant(divisor, speedup)
if constant is None:
return torch.floor_divide
else:
new_op = partial(torch.floor_divide, other=constant)
return new_op
def mul_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
constant = None
for i in range(2):
input_i = inputs[i]
debug_name = input_i.debugName()
if debug_name not in speedup.internal_result:
constant = parse_constant(input_i, speedup)
# both two inputs cannot be constants at the same time
break
if constant is None:
return torch.mul
else:
new_mul = partial(torch.mul, constant)
return new_mul
def transpose_python(node, speedup):
return torch.t
def transpose2_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim_1 = inputs[1].toIValue()
dim_2 = inputs[2].toIValue()
new_transpose = partial(torch.transpose, dim0=dim_1, dim1=dim_2)
return new_transpose
def matmul_python(node, speedup):
return torch.matmul
def div_python(node, speedup):
# The second input parameter of torch.div can be a
# tensor or a constant; if it is a constant, we bind it
# into a partial before returning
c_node = node.key_node
inputs = list(c_node.inputs())
if inputs[1].debugName() in speedup.internal_result:
# the second input parameters is the output of the other
# nodes
return torch.div
else:
other = inputs[1].toIValue()
new_div = partial(torch.div, other=other)
return new_div
def softmax_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim = inputs[1].toIValue()
new_softmax = partial(torch.softmax, dim=dim)
return new_softmax
def contiguous_python(node, speedup):
class contiguousModule(torch.nn.Module):
def forward(self, x):
return x.contiguous()
return contiguousModule()
def gelu_python(node, speedup):
return torch.nn.GELU()
def avgpool2d_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
kernel_size = translate_list(inputs[1], speedup)
stride = translate_list(inputs[2], speedup)
padding = translate_list(inputs[3], speedup)
new_avgpool = partial(torch.nn.functional.avg_pool2d,
kernel_size=kernel_size, stride=stride, padding=padding)
return new_avgpool
def adaptive_avgpool_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
output_size = translate_list(inputs[1], speedup)
new_avgpool = torch.nn.AdaptiveAvgPool2d(output_size)
return new_avgpool
def tupleunpack_python(node, speedup):
# Note: tuple unpack should only exist at the
# end of the model, and there is no need to replace it or propagate the mask
return None
def num2tensor_python(node, speedup):
return torch.nn.Identity()
def exp_python(node, speedup):
return torch.exp
def squeeze_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim = None
if len(inputs) > 1:
dim = parse_constant(inputs[1], speedup)
new_squeeze = partial(torch.squeeze, dim=dim)
return new_squeeze
##########################################################
# Split Line
# Following module/functions cannot be translated into a
# single function, so we use torch.nn.Module to wrap the
# the core function, and return the torch.nn.Module instead
##########################################################
def slice_python(node, speedup):
class SliceMoudle(torch.nn.Module):
def __init__(self, sliceobj):
super(SliceMoudle, self).__init__()
self.sliceobj = sliceobj
def forward(self, x, *args):
# args is for the slice dimension and indexes, however,
# we already get them from the cpp nodes. Note, though, we
# don't need the slice indexes any more, we cannot remove this
# parameter here, because, there may be multiple inputs passed from
# previous nodes such as aten::size
logger.info('Model has Slice operation, and the operand size=%s, Slice object:%s', str(
x.size()), str(self.sliceobj))
return x[self.sliceobj]
c_node = node.key_node
inputs = list(c_node.inputs())
slice_dim = parse_constant(inputs[1], speedup)
slice_start = parse_constant(inputs[2], speedup)
slice_end = parse_constant(inputs[3], speedup)
slice_step = parse_constant(inputs[4], speedup)
slice_obj = slice(slice_start, slice_end, slice_step)
slice_list = []
for _ in range(slice_dim):
slice_list.append(slice(None, None))
logger.info('Slice dim:%s, Slice obj:%s', str(slice_dim), str(slice_obj))
slice_list.append(slice_obj)
return SliceMoudle(tuple(slice_list))
def select_python(node, speedup):
class SelectModule(torch.nn.Module):
def __init__(self, dim, index):
super(SelectModule, self).__init__()
self.dim = dim
self.index = index
def forward(self, x):
return x.select(self.dim, self.index)
c_node = node.key_node
inputs = list(c_node.inputs())
dim = inputs[1].toIValue()
index = inputs[2].toIValue()
return SelectModule(dim, index)
def size_python(node, speedup):
# return None
class SizeMoudle(torch.nn.Module):
def __init__(self, sizedim):
super(SizeMoudle, self).__init__()
self.sizedim = sizedim
def forward(self, x):
return torch.as_tensor([x.size(self.sizedim)], dtype=torch.long)
# return torch.tensor(x.size(self.sizedim))
c_node = node.key_node
inputs = list(c_node.inputs())
size_dim = inputs[1].toIValue()
return SizeMoudle(size_dim)
def toint_python(node, speedup):
class ToIntModule(torch.nn.Module):
def forward(self, x):
return x.to(torch.int)
return ToIntModule()
def view_python(node, speedup):
class ViewModule(torch.nn.Module):
def __init__(self, shape):
super(ViewModule, self).__init__()
self.shape = shape
logger.info('View Module output size: %s', str(self.shape))
def forward(self, *args):
return args[0].view(self.shape)
c_node = node.key_node
inputs = list(c_node.inputs())
shape = translate_list(inputs[1], speedup)
return ViewModule(shape)
def reshape_python(node, speedup):
class ReshapeModule(torch.nn.Module):
def __init__(self, shape):
super(ReshapeModule, self).__init__()
self.shape = shape
logger.info('Reshape Module output size: %s', str(self.shape))
def forward(self, *args):
return args[0].view(self.shape)
c_node = node.key_node
inputs = list(c_node.inputs())
shape = translate_list(inputs[1], speedup)
return ReshapeModule(shape)
def permute_python(node, speedup):
class PermuteModule(torch.nn.Module):
def __init__(self, dimlist):
super(PermuteModule, self).__init__()
self.dimlist = dimlist
def forward(self, x):
return x.permute(self.dimlist)
c_node = node.key_node
inputs = list(c_node.inputs())
dim_list = translate_list(inputs[1], speedup)
return PermuteModule(dim_list)
def getattr_python(node, speedup):
"""
Note: ops starting with prim:: are not taken as the key node,
so we directly pass the cpp node into this function.
Parameters
----------
node: torch._C.Node
The cpp node of prim::Getattr
speedup: ModelSpeedup
The corresponding speedup object.
"""
class GetModule(torch.nn.Module):
def __init__(self, key):
super(GetModule, self).__init__()
self.key = key
def forward(self, obj):
logger.info('Get attribute: %s', self.key)
return getattr(obj, self.key)
# get the name of the attribute, for example
# prim::GetAttr[name="module_list"](%self.1)
assert node.kind() == 'prim::GetAttr'
pattern = r'\[name="(.*?)"\]'
key_words = re.findall(pattern, str(node))
assert len(key_words) == 1
return GetModule(key_words[0])
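The attribute name is recovered from the printed form of the prim::GetAttr node; a small standalone check of that regex (the node string below is a typical example, not taken from a real trace):

import re
node_str = '%1 : ModuleList = prim::GetAttr[name="module_list"](%self.1)'
pattern = r'\[name="(.*?)"\]'
assert re.findall(pattern, node_str) == ['module_list']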
def upsample_bilinear2d_python(node, speedup):
class UpsampleModule(torch.nn.Module):
def __init__(self, size_list, scale_list):
super(UpsampleModule, self).__init__()
self.size_list = size_list
self.scale_list = scale_list
def forward(self, *args):
"""
The first element of args is the target tensor to upsample;
the following parameters are useless, because we already
got the size_list and the scale_list by parsing the cpp_nodes.
"""
return torch.nn.functional.upsample_bilinear(args[0],
size=self.size_list, scale_factor=self.scale_list)
c_node = node.key_node
inputs = list(c_node.inputs())
size_list_node = inputs[1].node()
scale_list_node = inputs[3].node()
size_list = None
scale_list = None
if size_list_node.kind() == 'prim::ListConstruct':
size_list = translate_list(inputs[1], speedup)
if scale_list_node.kind() == 'prim::ListConstruct':
scale_list = translate_list(inputs[3], speedup)
return UpsampleModule(size_list, scale_list)
def typeas_python(node, speedup):
"""
Currently only supports type_as to float.
TODO: support more types in type_as; we need to figure out
how to get the scalar type from torch._C.TensorType.
"""
class TypeasModule(torch.nn.Module):
def __init__(self, dtype=torch.float):
super(TypeasModule, self).__init__()
self.example = torch.zeros(1, dtype=dtype)
def forward(self, x):
return x.type_as(self.example)
return TypeasModule()
def to_python(node, speedup):
# for the time being, only device parameters are supported
class ToModule(torch.nn.Module):
def __init__(self, device):
super(ToModule, self).__init__()
self.device = device
def forward(self, x):
return x.to(self.device)
c_node = node.key_node
inputs = list(c_node.inputs())
device = inputs[3].toIValue()
return ToModule(device)
def cat_python(node, speedup):
class CatModule(torch.nn.Module):
def __init__(self, cat_dim):
super(CatModule, self).__init__()
self.cat_dim = cat_dim
def forward(self, *args):
return torch.cat(args, dim=self.cat_dim)
c_node = node.key_node
inputs = list(c_node.inputs())
dim = inputs[1].toIValue()
return CatModule(dim)
trans_from_jit_to_python = {
'aten::add': add_python,
'aten::add_': add_python,
'aten::mul': mul_python,
'aten::mul_': mul_python,
'aten::relu': relu_python,
'aten::relu_': relu_inplace_python,
'aten::sigmoid': sigmoid_python,
'aten::sigmoid_': sigmoid_python,
# tanh behaves like relu for the purpose of mask inference
'aten::tanh': relu_python,
'aten::tanh_': relu_python,
'aten::flatten': flatten_python,
'aten::mean': mean_python,
'aten::dropout': dropout_python,
'aten::slice': slice_python,
'aten::select': select_python,
'aten::size': size_python,
'aten::t': transpose_python,
'aten::transpose': transpose2_python,
'aten::Int': toint_python,
'aten::view': view_python,
'aten::reshape': reshape_python,
'aten::permute': permute_python,
'aten::matmul': matmul_python,
'aten::div': div_python,
'aten::floor_divide': floor_div_python,
'aten::softmax': softmax_python,
'aten::contiguous': contiguous_python,
'aten::gelu': gelu_python,
'aten::cat': cat_python,
'aten::avg_pool2d': avgpool2d_python,
'aten::max_pool2d': avgpool2d_python,
'aten::adaptive_avg_pool2d': adaptive_avgpool_python,
'aten::to': to_python,
'aten::type_as': typeas_python,
'aten::upsample_bilinear2d': upsample_bilinear2d_python,
'aten::exp': exp_python,
'aten::squeeze': squeeze_python,
'prim::TupleUnpack': tupleunpack_python,
'prim::ListUnpack': tupleunpack_python,
'prim::NumToTensor': num2tensor_python,
'prim::GetAttr': getattr_python
}
def jit_to_python_function(node, speedup):
"""
Return a callable object to infer the mask according to the
node.op_type.
Parameters
---------
node: NodeGroup
The target node to inference the mask
speedup: ModelSpeedup
The speedup object of the target model.
Returns
------
func: callable object (nn.Module/function)
The translated function used to infer the mask;
if the current op_type is not supported, then we return None.
"""
logger.debug(
'Translate C function %s into its python version', node.op_type)
if node.op_type not in trans_from_jit_to_python:
logger.error(
'%s is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~', node.op_type)
# return None to skip the mask inference for this node
return None
return trans_from_jit_to_python[node.op_type](node, speedup)
from .utils import *
\ No newline at end of file
@@ -4,10 +4,10 @@ import os
import logging
import torch
import numpy as np
-from .shape_dependency import ChannelDependency, GroupDependency, CatPaddingDependency, InputChannelDependency
+from .shape_dependency import ChannelDependency, GroupDependency, InputChannelDependency
from .utils import get_module_by_name
# logging.basicConfig(level = logging.DEBUG)
-_logger = logging.getLogger(__name__)
+_logger = logging.getLogger('FixMaskConflict')
def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
@@ -21,7 +21,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
        A dict object that stores the masks or the path of the mask file
    model : torch.nn.Module
        model to fix the mask conflict
-    dummy_input : torch.Tensor
+    dummy_input : torch.Tensor/list of tensors/dict of tensors
        input example to trace the model
    traced : torch._C.torch.jit.TopLevelTracedModule
        the traced model of the target model, if this parameter is not None,
@@ -48,9 +48,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
    masks = fix_group_mask.fix_mask()
    fix_channel_mask = ChannelMaskConflict(masks, model, dummy_input, traced)
    masks = fix_channel_mask.fix_mask()
-    padding_cat_mask = CatMaskPadding(masks, model, dummy_input, traced)
-    masks = padding_cat_mask.fix_mask()
-    return masks, fix_channel_mask.conv_prune_dim
+    return masks
class MaskFix:
@@ -78,70 +76,6 @@ class MaskFix:
        torch.save(self.masks, path)
class CatMaskPadding(MaskFix):
def __init__(self, masks, model, dummy_input=None, traced=None):
"""
CatMaskPadding find the layers whose output tensor is passed
to the same cat operation. The cat operation concatnates the
masks of the input tensors as the output mask, so when some
of the input layers of the cat operation are not pruned, we still
need to pass the masks of these non-pruned layers(the mask are
all ones) to the cat operation to ensure the shape of the output
mask is right.
Parameters
----------
masks : dict
a dict object that stores the masks
model : torch.nn.Module
model to fix the mask conflict
dummy_input : torch.Tensor
input example to trace the model
traced : torch._C.torch.jit.TopLevelTracedModule
the traced model of the target model, is this parameter is not None,
we donnot use the model and dummpy_input to get the trace graph.
"""
super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced)
def fix_mask(self):
cat_padding_depen = CatPaddingDependency(
self.model, self.dummy_input, self.traced)
name_to_module = {}
for name, module in self.model.named_modules():
name_to_module[name] = module
depen = cat_padding_depen.dependency_sets
for layers in depen:
device = None
count = 0
for layer in layers:
if layer in self.masks:
count += 1
if device is None:
device = self.masks[layer]['weight'].device
if count == 0:
# no layer is pruned
continue
elif count == len(layers):
# all the layers have been pruned
continue
# pad the mask for the non-pruned layers
for layer in layers:
if layer in self.masks:
continue
module = name_to_module[layer]
w_shape = module.weight.data.size()
w_mask = torch.ones(w_shape).to(device)
b_mask = None
if hasattr(module, 'bias') and module.bias is not None:
# module.bias may be None
b_shape = module.bias.data.size()
b_mask = torch.ones(b_shape).to(device)
self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
return self.masks
class GroupMaskConflict(MaskFix):
    def __init__(self, masks, model=None, dummy_input=None, traced=None):
        """
@@ -172,9 +106,11 @@ class GroupMaskConflict(MaskFix):
        group_depen = GroupDependency(
            self.model, self.dummy_input, self.traced)
        depens = group_depen.dependency
+        min_groups = group_depen.min_groups
        _logger.info(depens)
        for layername in depens:
-            group = depens[layername]
+            group_max = depens[layername]
+            group_min = min_groups[layername]
            if layername not in self.masks:
                # this layer not pruned
                continue
@@ -187,29 +123,43 @@ class GroupMaskConflict(MaskFix):
                # In fine-grained pruning, skip this layer
                _logger.info('Layers %s using fine-grained pruning', layername)
                continue
-            assert shape[0] % group == 0
+            assert shape[0] % group_max == 0
            # Find the number of masked filters for each group (mini_masked).
            # Because the pruned filters must still be divisible into the same
            # number of groups, we can only keep mini_masked filters masked in
            # each group; the rest have to be unmasked again.
-            step = shape[0] / group
+            step = shape[0] / group_max
            group_masked = []
-            for i in range(group):
+            for i in range(group_max):
                _start = step * i
                _end = step * (i + 1)
                _tmp_list = list(
                    filter(lambda x: _start <= x and x < _end, all_zeros))
                group_masked.append(_tmp_list)
            mini_masked = min([len(x) for x in group_masked])
+            need_unmask = set()
            for gm in group_masked:
                for i in range(mini_masked, len(gm)):
                    # To keep the output channel number divisible by the
                    # number of groups, the following filters are unmasked.
                    pos = gm[i]
-                    self.masks[layername]['weight'][pos] = torch.ones(
-                        shape[1:])
-                    if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
-                        self.masks[layername]['bias'][pos] = 1
+                    need_unmask.add(pos)
+            step = shape[0] / group_min
+            for i in range(group_min):
+                _start = step * i
+                _end = step * (i+1)
+                _tmp_list = list(
+                    filter(lambda x: _start <= x and x < _end, all_zeros))
+                if len(_tmp_list) == step:
+                    # if the whole group is removed, then we don't have to
+                    # unmask the filters in this group
+                    for pos in _tmp_list:
+                        if pos in need_unmask:
+                            need_unmask.remove(pos)
+            for pos in need_unmask:
+                self.masks[layername]['weight'][pos] = torch.ones(shape[1:])
+                if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
+                    self.masks[layername]['bias'][pos] = 1
        return self.masks
@@ -234,9 +184,14 @@ class ChannelMaskConflict(MaskFix):
        super(ChannelMaskConflict, self).__init__(
            masks, model, dummy_input, traced)
        self.conv_prune_dim = detect_mask_prune_dim(masks, model)
-        _logger.info('detected conv prune dim: %s', self.conv_prune_dim)
+        _logger.info('Detected conv prune dim: %d', self.conv_prune_dim)
    def fix_mask(self):
+        """
+        Fix the mask conflict before the mask inference for the layers that
+        has shape dependencies. This function should be called before the
+        mask inference of the 'speedup' module.
+        """
        """
        Fix the mask conflict before the mask inference for the layers that
        has shape dependencies. This function should be called before the
@@ -274,7 +229,12 @@ class ChannelMaskConflict(MaskFix):
                if (channel_mask.sum() * (mask.numel() / mask.shape[self.conv_prune_dim])).item() != (mask > 0).sum().item():
                    fine_grained = True
            elif type(m).__name__ == 'Linear':
-                channel_masks.append((mask.abs().sum(0) != 0).int())
+                if self.conv_prune_dim == 1:
+                    channel_masks.append(
+                        (mask.abs().sum(0) != 0).int())
+                else:
+                    channel_masks.append(
+                        (mask.abs().sum(1) != 0).int())
            elif type(m).__name__ == 'BatchNorm2d':
                channel_masks.append(mask.int())
            elif type(m).__name__ == 'ConvTranspose2d':
@@ -293,9 +253,7 @@ class ChannelMaskConflict(MaskFix):
                # no mask means not pruned, equivalent to full masks
                channel_masks.append(None)
            if fine_grained:
-                _logger.info(
-                    'fine-grained mask detected, skip solving conflict for this set: %s', dset)
-                continue
+                _logger.info("Fine-grained mask detected")
            if all(x is None for x in channel_masks):
                continue
            num_channels_list = [len(x)
@@ -306,7 +264,8 @@ class ChannelMaskConflict(MaskFix):
            for i, dim_mask in enumerate(channel_masks):
                if dim_mask is None:
-                    channel_masks[i] = torch.ones(num_channels).int().to(device)
+                    channel_masks[i] = torch.ones(
+                        num_channels).int().to(device)
            # merge masks with 'or'
            merged_channel_mask = channel_masks[0].clone()
@@ -329,19 +288,22 @@ class ChannelMaskConflict(MaskFix):
                    else:
                        new_mask[:, merged_index, :, :] = 1.
                elif type(m).__name__ == 'Linear':
-                    new_mask[:, merged_index] = 1.
+                    if self.conv_prune_dim == 0:
+                        new_mask[merged_index, :] = 1
+                    elif self.conv_prune_dim == 1:
+                        new_mask[:, merged_index] = 1.
                elif type(m).__name__ == 'BatchNorm2d':
                    new_mask = merged_channel_mask.type_as(orig_mask)
                else:
                    raise RuntimeError(
                        f'unsupported module type: {type(m).__name__}')
                self.masks[name]['weight'] = new_mask
                if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
                    if type(m).__name__ == 'Conv2d':
                        assert self.conv_prune_dim == 0
-                    self.masks[name]['bias'] = merged_channel_mask.type_as(
-                        self.masks[name]['bias'])
+                    if self.conv_prune_dim == 0:
+                        self.masks[name]['bias'] = merged_channel_mask.type_as(
+                            self.masks[name]['bias'])
        return self.masks
@@ -349,14 +311,12 @@
def detect_mask_prune_dim(masks, model):
    """
    Detect how the masks of convolutional layers are pruned.
    Parameters
    ----------
    masks: dict
        A dict object that stores the masks.
    model: nn.Module
        Model object which the mask can be applied on.
    Returns:
    -------
    How the masks of convolutional layers are pruned, this depends on pruning algorithms, it should
......
@@ -3,18 +3,34 @@
import csv
import logging
+import numpy as np
-__all__ = ['ChannelDependency', 'GroupDependency',
-           'CatPaddingDependency', 'InputChannelDependency']
+__all__ = ['ChannelDependency', 'GroupDependency', 'InputChannelDependency']
CONV_TYPE = 'aten::_convolution'
ADD_TYPES = ['aten::add', 'aten::add_']
+MUL_TYPES = ['aten::mul', 'aten::mul_']
CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')
RESHAPE_OPS = [CAT_TYPE, 'aten::view',
               'aten::reshape', 'aten::flatten', 'aten::mean']
def lcm_list(L):
lcm = 1
for i in L:
lcm = np.lcm(lcm, i)
return lcm
def gcd_list(L):
gcd = L[0]
for i in L:
gcd = np.gcd(gcd, i)
return gcd
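A quick check of the two helpers above (they simply fold numpy's pairwise lcm/gcd over the list, so they can be exercised standalone):

assert lcm_list([2, 4, 8]) == 8
assert gcd_list([2, 4, 8]) == 2
assert lcm_list([3, 4]) == 12 and gcd_list([3, 4]) == 1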
class Dependency:
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
@@ -38,6 +54,35 @@ class Dependency:
        raise NotImplementedError
def reshape_break_channel_dependency(op_node):
"""
The reshape operations such as (reshape, view, flatten) may break
the channel dependency. We need to check the input parameters of
these reshape operations to see if this reshape node will break
the channel dependency. However, it's complicated to analyze the input
parameters of each reshape function and infer whether it will break the channel
dependency. So currently, we just check if the input channel count and the output
channel count are the same; if so, we can say the original reshape function
doesn't intend to change the number of channels, which means the channel
dependency is not broken. In contrast, if the original reshape operation intends
to change the number of channels, it breaks the channel dependency.
Parameters
----------
op_node: NodePyOP
An OP node of the graph.
Returns
-------
bool
True if this operation breaks the channel dependency.
"""
in_shape = op_node.auxiliary['in_shape']
out_shape = op_node.auxiliary['out_shape']
in_channel = in_shape[1]
out_channel = out_shape[1]
return in_channel != out_channel
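A toy check of the rule above with hand-made stand-ins for NodePyOP (only the auxiliary shapes matter here):

flatten_node = type('Node', (), {'auxiliary': {'in_shape': [1, 512, 7, 7], 'out_shape': [1, 25088]}})()
view_node = type('Node', (), {'auxiliary': {'in_shape': [1, 64, 32, 32], 'out_shape': [1, 64, 1024]}})()
assert reshape_break_channel_dependency(flatten_node)       # 512 -> 25088 channels: dependency broken
assert not reshape_break_channel_dependency(view_node)      # 64 -> 64 channels: dependency kept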
class ChannelDependency(Dependency):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
@@ -80,6 +125,9 @@ class ChannelDependency(Dependency):
                # find the first met conv
                parent_layers.append(curnode.name)
                continue
+            elif curnode.op_type in RESHAPE_OPS:
+                if reshape_break_channel_dependency(curnode):
+                    continue
            parents = self.graph.find_predecessors(curnode.unique_name)
            parents = [self.graph.name_to_node[name] for name in parents]
            for parent in parents:
@@ -176,7 +224,7 @@ class ChannelDependency(Dependency):
        d_sets = []
        visited = set()
        for node in self.graph.nodes_py.nodes_op:
-            if node.op_type != 'Conv2d' or node in visited:
+            if (node.op_type != 'Conv2d' and node.op_type != 'Linear') or node in visited:
                continue
            tmp_set = set()
            if node.name not in self.dependency:
@@ -190,35 +238,6 @@ class ChannelDependency(Dependency):
        return d_sets
def reshape_break_channel_dependency(op_node):
"""
The reshape operations such as (reshape, view, flatten) may break
the channel dependency. We need to check the input parameters of
these reshape operations to check if this reshape node will break
the channel dependency. However, it's complicated to analyze the the input
parameters for each reshape function and infer if it will break the channel
dependency. So currently, we just check if the input channel and the output
channel is the same, if so, then we can say the original reshape function
doesn't want to change the number of the channels, which means the channel
dependency is not broken. In contrast, the original reshap operation wants
to change the number of channels, so it breaks the channel dependency.
Parameters
----------
opnode: NodePyOP
A Op node of the graph.
Returns
-------
bool
If this operation will break the channel dependency.
"""
in_shape = op_node.auxiliary['in_shape']
out_shape = op_node.auxiliary['out_shape']
in_channel = in_shape[1]
out_channel = out_shape[1]
return in_channel != out_channel
class InputChannelDependency(ChannelDependency):
    """
    Some pruners may prune the input channel of the convolutional
@@ -295,67 +314,6 @@ class InputChannelDependency(ChannelDependency):
        self.dependency[layer] = dependency_set
class CatPaddingDependency(ChannelDependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
super(CatPaddingDependency, self).__init__(
model, dummy_input, traced_model)
def build_dependency(self):
"""
Build the cat padding dependencies.
If the output features of several layers are stitched together
by cat operation, then these layers have cat padding dependencies.
This is because when inferring the cat mask, we need all the input
masks for the cat operation. At this time we need to know the source
of all input vectors of a cat operation.
"""
for node in self.graph.nodes_py.nodes_op:
parent_layers = []
if node.op_type == CAT_TYPE:
parent_layers = self._get_parent_layers(node)
dependency_set = set(parent_layers)
# merge the dependencies
for parent in parent_layers:
if parent in self.dependency:
dependency_set.update(self.dependency[parent])
# save the dependencies
for _node in dependency_set:
self.dependency[_node] = dependency_set
@property
def dependency_sets(self):
d_sets = []
visited = set()
for nodename in self.dependency:
if nodename in visited:
continue
d_sets.append(self.dependency[nodename])
return d_sets
def export(self, filepath):
"""
Export the dependencies into a file.
In the output file, each line contains a set of layers
whose output features are stitched together by the cat
operation.
output example:
Dependency Set, Layers
set1, Conv1, Conv2
set2, Conv3, Conv4
"""
header = ['Dependency Set', 'Layers']
setid = 0
with open(filepath, 'w') as csvf:
csv_w = csv.writer(csvf, delimiter=',')
csv_w.writerow(header)
for layers in self.dependency_sets:
setid += 1
row = ['Set %d' % setid]
row.extend(list(layers))
csv_w.writerow(row)
class GroupDependency(Dependency):
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
@@ -372,6 +330,7 @@ class GroupDependency(Dependency):
        if we alreay has the traced graph of the target model, we donnot
        need to trace the model again.
        """
+        self.min_groups = {}
        super(GroupDependency, self).__init__(model, dummy_input, traced_model)
    def _get_parent_convs(self, node):
@@ -451,27 +410,33 @@ class GroupDependency(Dependency):
        key: the name of conv layers, value: the minimum value that the number of
        filters should be divisible to.
        """
+        self.groups = {}
        for node in self.graph.nodes_py.nodes_op:
            if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
                group = self._get_conv_groups(node)
-                if node.name in self.dependency:
+                if node.name in self.groups:
                    # the conv layer whose group is larger than 1 will require that
                    # it's number of output channel to be divisible by the number of group.
-                    self.dependency[node.name] = max(
-                        self.dependency[node.name], group)
+                    self.groups[node.name].append(group)
                else:
-                    self.dependency[node.name] = group
+                    self.groups[node.name] = [group]
                if group > 1:
                    # for the conv layer whose group is larger than 1, it will require the number
                    # of output channels of their parent conv layer to be divisible by group.
                    parent_convs = self._get_parent_convs(node)
                    for parent in parent_convs:
-                        if parent in self.dependency:
-                            self.dependency[parent] = max(
-                                self.dependency[parent], group)
+                        if parent in self.groups:
+                            self.groups[parent].append(group)
                        else:
-                            self.dependency[parent] = group
+                            self.groups[parent] = [group]
+        for name in self.groups:
+            self.dependency[name] = lcm_list(self.groups[name])
+            if min(self.groups[name]) == gcd_list(self.groups[name]):
+                self.min_groups[name] = min(self.groups[name])
+            else:
+                self.min_groups[name] = 1
        return self.dependency
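A worked example of the new bookkeeping, reusing the lcm_list/gcd_list helpers defined at the top of this file (the group numbers are made up):

groups = [4, 32]                  # collected from the layer itself and its child convs
dependency = lcm_list(groups)     # 32: the filter count must stay divisible by 32
min_group = min(groups) if min(groups) == gcd_list(groups) else 1
assert dependency == 32 and min_group == 4    # gcd([4, 32]) == 4 == min(groups)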
    def export(self, filepath):
@@ -501,3 +466,110 @@ class GroupDependency(Dependency):
    @property
    def dependency_sets(self):
        return self.dependency
class ReshapeDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
Some models may have view/reshape functions; such functions may have fixed parameters
that cannot be replaced at all. Therefore, these functions may impose constraints on
their input shapes. In this class, we find the direct input conv/linear layers of these
reshape functions. If you get a shape conflict when running forward inference on the
sped-up model, please try removing these layers from the pruner config list and try again.
Parameters
----------
model : torch.nn.Module
The model to be analyzed.
data : torch.Tensor
The example input data to trace the network architecture.
traced_model : torch._C.Graph
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
super(ReshapeDependency, self).__init__(
model, dummy_input, traced_model)
def _get_parent_layers(self, node):
"""
Find the nearest father conv layers for the target node.
Parameters
---------
node : torch._C.Node
target node.
Returns
-------
parent_layers: list
nearest father conv/linear layers for the target worknode.
"""
parent_layers = []
queue = []
queue.append(node)
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
# find the first met conv
parent_layers.append(curnode.name)
continue
parents = self.graph.find_predecessors(curnode.unique_name)
parents = [self.graph.name_to_node[name] for name in parents]
for parent in parents:
queue.append(parent)
return parent_layers
def build_dependency(self):
"""
Build the channel dependency for the conv layers
in the model.
"""
# unpack the tuple/list manually before analyze the
# channel dependency
self.graph.unpack_manually()
for node in self.graph.nodes_py.nodes_op:
parent_layers = []
# find the nodes that contain aten::view
# or aten::reshape operations
if node.op_type in ['aten::view', 'aten::reshape']:
logger.info('Detect reshape-like functions: %s', node.op_type)
parent_layers = self._get_parent_layers(node)
logger.debug('Parent layers: %s', parent_layers)
self.dependency[node.unique_name] = parent_layers
def export(self, filepath):
"""
export the reshape dependencies as a csv file.
Output example:
Reshape OP, Dependent Layers
model.view.1,layer1.1.conv2,layer1.0.conv2,conv1
model.mean.1,layer1.0.conv1
model.reshape.1,layer1.1.conv1
"""
header = ['Reshape OP', 'Dependent Layers']
with open(filepath, 'w') as csvf:
csv_w = csv.writer(csvf, delimiter=',')
csv_w.writerow(header)
for reshape_op in self.dependency:
row = [reshape_op] + list(self.dependency[reshape_op])
csv_w.writerow(row)
@property
def dependency_sets(self):
"""
Get the list of the dependency set.
Returns
-------
dependency_sets : list
flat list of the conv/linear layers that directly feed
reshape-like ops, for example ['conv1', 'conv2', 'conv3'].
"""
d_sets = []
for reshape_node in self.dependency:
d_sets.extend(self.dependency[reshape_node])
d_sets = list(set(d_sets))
return d_sets
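An illustrative usage sketch of the class above (the model and file name are arbitrary; the reported layers depend on which view/reshape ops appear in the trace and may be empty for models that use flatten instead):

import torch
import torchvision.models as models
from nni.compression.pytorch.utils.shape_dependency import ReshapeDependency

net = models.resnet18(pretrained=False)
dummy = torch.randn(1, 3, 224, 224)
dep = ReshapeDependency(net, dummy)
print(dep.dependency_sets)             # conv/linear layers feeding view/reshape ops
dep.export('./reshape_dependency.csv')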
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import torch
from .shape_dependency import ReshapeDependency
torch_float_dtype = [torch.float, torch.float16, torch.float32, torch.float64, torch.half, torch.double]
torch_integer_dtype = [torch.uint8, torch.int16, torch.short, torch.int32, torch.long, torch.bool]
def get_module_by_name(model, module_name): def get_module_by_name(model, module_name):
""" """
...@@ -28,3 +33,50 @@ def get_module_by_name(model, module_name): ...@@ -28,3 +33,50 @@ def get_module_by_name(model, module_name):
return model, leaf_module return model, leaf_module
else: else:
return None, None return None, None
def rand_like_with_shape(shape, ori_t):
"""
Return a new random tensor like the original
tensor.
"""
assert isinstance(ori_t, torch.Tensor)
device = ori_t.device
dtype = ori_t.dtype
require_grad = ori_t.requires_grad
lower_bound = int(torch.min(ori_t))
higher_bound = int(torch.max(ori_t))
# reuse the module-level integer dtype list; torch.randint needs python int bounds
if dtype in torch_integer_dtype:
return torch.randint(lower_bound, higher_bound + 1, shape, dtype=dtype, device=device)
else:
return torch.rand(shape, dtype=dtype, device=device, requires_grad=require_grad)
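A quick, illustrative check of the helper above (assumes it is in scope, e.g. run in the same module): the new tensor keeps the dtype and device of the reference tensor while taking the requested shape.

ref = torch.randint(0, 10, (4, 8), dtype=torch.long)
new_t = rand_like_with_shape((2, 8), ref)
assert new_t.shape == (2, 8) and new_t.dtype == torch.long and new_t.device == ref.device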
def randomize_tensor(tensor, start=1, end=100):
"""
Randomize the target tensor according to the given
range.
"""
assert isinstance(tensor, torch.Tensor)
if tensor.dtype in torch_integer_dtype:
# integer tensors can only be randomized by torch.randint
# torch.randint(int(start), int(end), tensor.size(), out=tensor.data, dtype=tensor.dtype)
pass
else:
# we can use nn.init.uniform_ to randomize this tensor
# Note: tensors with an integer dtype cannot be randomized
# with nn.init.uniform_
torch.nn.init.uniform_(tensor.data, start, end)
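An illustrative check of randomize_tensor (assumes the helper is in scope): float tensors are randomized in place within [start, end], while integer tensors are deliberately left untouched.

w = torch.zeros(3, 3)
randomize_tensor(w, start=1, end=2)
assert float(w.min()) >= 1.0 and float(w.max()) <= 2.0

idx = torch.zeros(4, dtype=torch.long)
randomize_tensor(idx)                 # no-op for integer tensors
assert torch.all(idx == 0)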
def not_safe_to_prune(model, dummy_input):
"""
Get the layers that are not safe to prune (pruning them may bring shape conflicts).
Parameters
----------
model: torch.nn.Module
The target model to prune.
dummy_input: torch.Tensor/list of torch.Tensor/tuple of Tensor
The example input used to trace the model.
"""
reshape_dset = ReshapeDependency(model, dummy_input)
return reshape_dset.dependency_sets
\ No newline at end of file
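A hedged usage sketch of not_safe_to_prune (the exact module path of this utils file is assumed here; adjust the import to wherever it lives in your checkout): drop the flagged layers from an auto-generated config list before running a pruner.

import torch
import torch.nn as nn
import torchvision.models as models
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.compression.pytorch.utils.utils import not_safe_to_prune  # path assumed

net = models.resnet18(pretrained=False)
dummy = torch.randn(1, 3, 224, 224)
unsafe = set(not_safe_to_prune(net, dummy))
cfg = [{'op_types': ['Conv2d'], 'sparsity': 0.5,
        'op_names': [name for name, m in net.named_modules()
                     if isinstance(m, nn.Conv2d) and name not in unsafe]}]
pruner = L1FilterPruner(net, cfg)
pruner.compress()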
...@@ -116,7 +116,7 @@ class AnalysisUtilsTest(TestCase): ...@@ -116,7 +116,7 @@ class AnalysisUtilsTest(TestCase):
pruner.export_model(ck_file, mask_file) pruner.export_model(ck_file, mask_file)
pruner._unwrap_model() pruner._unwrap_model()
# Fix the mask conflict # Fix the mask conflict
fixed_mask, _ = fix_mask_conflict(mask_file, net, dummy_input) fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
# use the channel dependency ground truth to check if # use the channel dependency ground truth to check if
# fix the mask conflict successfully # fix the mask conflict successfully
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import logging
import os import os
import gc
import psutil import psutil
import sys import sys
import numpy as np import numpy as np
...@@ -9,18 +11,20 @@ import torch ...@@ -9,18 +11,20 @@ import torch
import torchvision.models as models import torchvision.models as models
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision.models.vgg import vgg16 from torchvision.models.vgg import vgg16, vgg11
from torchvision.models.resnet import resnet18 from torchvision.models.resnet import resnet18
from torchvision.models.mobilenet import mobilenet_v2
import unittest import unittest
from unittest import TestCase, main from unittest import TestCase, main
from nni.compression.pytorch import ModelSpeedup, apply_compression_results from nni.compression.pytorch import ModelSpeedup, apply_compression_results
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner
from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker
from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2 BATCH_SIZE = 2
# the relative distance # the relative distance
RELATIVE_THRESHOLD = 0.01 RELATIVE_THRESHOLD = 0.01
...@@ -105,6 +109,55 @@ class TransposeModel(torch.nn.Module): ...@@ -105,6 +109,55 @@ class TransposeModel(torch.nn.Module):
return x return x
class TupleUnpack_backbone(nn.Module):
def __init__(self, width):
super(TupleUnpack_backbone, self).__init__()
self.model_backbone = mobilenet_v2(
pretrained=False, width_mult=width, num_classes=3)
def forward(self, x):
x1 = self.model_backbone.features[:7](x)
x2 = self.model_backbone.features[7:14](x1)
x3 = self.model_backbone.features[14:18](x2)
return [x1, x2, x3]
class TupleUnpack_FPN(nn.Module):
def __init__(self):
super(TupleUnpack_FPN, self).__init__()
self.conv1 = nn.Conv2d(32, 48, kernel_size=(
1, 1), stride=(1, 1), bias=False)
self.conv2 = nn.Conv2d(96, 48, kernel_size=(
1, 1), stride=(1, 1), bias=False)
self.conv3 = nn.Conv2d(320, 48, kernel_size=(
1, 1), stride=(1, 1), bias=False)
# self.init_weights()
def forward(self, inputs):
"""Forward function."""
laterals = []
laterals.append(self.conv1(inputs[0])) # inputs[0]==x1
laterals.append(self.conv2(inputs[1])) # inputs[1]==x2
laterals.append(self.conv3(inputs[2])) # inputs[2]==x3
return laterals
class TupleUnpack_Model(nn.Module):
def __init__(self):
super(TupleUnpack_Model, self).__init__()
self.backbone = TupleUnpack_backbone(1.0)
self.fpn = TupleUnpack_FPN()
def forward(self, x):
x1 = self.backbone(x)
out = self.fpn(x1)
return out
dummy_input = torch.randn(2, 1, 28, 28) dummy_input = torch.randn(2, 1, 28, 28)
SPARSITY = 0.5 SPARSITY = 0.5
MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth' MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'
...@@ -129,6 +182,7 @@ def generate_random_sparsity(model): ...@@ -129,6 +182,7 @@ def generate_random_sparsity(model):
'sparsity': sparsity}) 'sparsity': sparsity})
return cfg_list return cfg_list
def generate_random_sparsity_v2(model): def generate_random_sparsity_v2(model):
""" """
Only select 50% layers to prune. Only select 50% layers to prune.
...@@ -139,9 +193,10 @@ def generate_random_sparsity_v2(model): ...@@ -139,9 +193,10 @@ def generate_random_sparsity_v2(model):
if np.random.uniform(0, 1.0) > 0.5: if np.random.uniform(0, 1.0) > 0.5:
sparsity = np.random.uniform(0.5, 0.99) sparsity = np.random.uniform(0.5, 0.99)
cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name], cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
'sparsity': sparsity}) 'sparsity': sparsity})
return cfg_list return cfg_list
def zero_bn_bias(model): def zero_bn_bias(model):
with torch.no_grad(): with torch.no_grad():
for name, module in model.named_modules(): for name, module in model.named_modules():
...@@ -231,19 +286,6 @@ def channel_prune(model): ...@@ -231,19 +286,6 @@ def channel_prune(model):
class SpeedupTestCase(TestCase): class SpeedupTestCase(TestCase):
def test_speedup_vgg16(self):
prune_model_l1(vgg16())
model = vgg16()
model.train()
ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
ms.speedup_model()
orig_model = vgg16()
assert model.training
assert model.features[2].out_channels == int(
orig_model.features[2].out_channels * SPARSITY)
assert model.classifier[0].in_features == int(
orig_model.classifier[0].in_features * SPARSITY)
def test_speedup_bigmodel(self): def test_speedup_bigmodel(self):
prune_model_l1(BigModel()) prune_model_l1(BigModel())
...@@ -253,7 +295,7 @@ class SpeedupTestCase(TestCase): ...@@ -253,7 +295,7 @@ class SpeedupTestCase(TestCase):
mask_out = model(dummy_input) mask_out = model(dummy_input)
model.train() model.train()
ms = ModelSpeedup(model, dummy_input, MASK_FILE) ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
ms.speedup_model() ms.speedup_model()
assert model.training assert model.training
...@@ -289,7 +331,7 @@ class SpeedupTestCase(TestCase): ...@@ -289,7 +331,7 @@ class SpeedupTestCase(TestCase):
new_model = TransposeModel() new_model = TransposeModel()
state_dict = torch.load(MODEL_FILE) state_dict = torch.load(MODEL_FILE)
new_model.load_state_dict(state_dict) new_model.load_state_dict(state_dict)
ms = ModelSpeedup(new_model, dummy_input, MASK_FILE) ms = ModelSpeedup(new_model, dummy_input, MASK_FILE, confidence=2)
ms.speedup_model() ms.speedup_model()
zero_bn_bias(ori_model) zero_bn_bias(ori_model)
zero_bn_bias(new_model) zero_bn_bias(new_model)
...@@ -297,26 +339,34 @@ class SpeedupTestCase(TestCase): ...@@ -297,26 +339,34 @@ class SpeedupTestCase(TestCase):
new_out = new_model(dummy_input) new_out = new_model(dummy_input)
ori_sum = torch.sum(ori_out) ori_sum = torch.sum(ori_out)
speeded_sum = torch.sum(new_out) speeded_sum = torch.sum(new_out)
print('Transpose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum)) print('Transpose Speedup Test: ori_sum={} speedup_sum={}'.format(
ori_sum, speeded_sum))
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \ assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD) (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
# FIXME: This test case might fail randomly, no idea why
# Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
def test_speedup_integration(self): def test_speedup_integration_small(self):
# skip this test on windows(7GB mem available) due to memory limit model_list = ['resnet18', 'mobilenet_v2', 'alexnet']
# Note: hack trick, may be updated in the future self.speedup_integration(model_list)
if 'win' in sys.platform or 'Win'in sys.platform:
print('Skip test_speedup_integration on windows due to memory limit!') def test_speedup_integration_big(self):
model_list = ['vgg11', 'vgg16', 'resnet34', 'squeezenet1_1',
'densenet121', 'resnet50', 'wide_resnet50_2']
mem_info = psutil.virtual_memory()
ava_gb = mem_info.available/1024.0/1024/1024
print('Available memory size: %.2f GB' % ava_gb)
if ava_gb < 8.0:
# memory size is too small that we may run into an OOM exception
# Skip this test in the pipeline test due to memory limitation
return return
self.speedup_integration(model_list)
def speedup_integration(self, model_list, speedup_cfg=None):
Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2] Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]
for model_name in ['resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121' , 'densenet169', # for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121',
# 'inception_v3' inception is too large and may fail the pipeline # # 'inception_v3' inception is too large and may fail the pipeline
'resnet50']: # 'resnet50']:
for model_name in model_list:
for gen_cfg_func in Gen_cfg_funcs: for gen_cfg_func in Gen_cfg_funcs:
kwargs = { kwargs = {
'pretrained': True 'pretrained': True
...@@ -334,7 +384,10 @@ class SpeedupTestCase(TestCase): ...@@ -334,7 +384,10 @@ class SpeedupTestCase(TestCase):
speedup_model.eval() speedup_model.eval()
# random generate the prune config for the pruner # random generate the prune config for the pruner
cfgs = gen_cfg_func(net) cfgs = gen_cfg_func(net)
print("Testing {} with compression config \n {}".format(model_name, cfgs)) print("Testing {} with compression config \n {}".format(
model_name, cfgs))
if len(cfgs) == 0:
continue
pruner = L1FilterPruner(net, cfgs) pruner = L1FilterPruner(net, cfgs)
pruner.compress() pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE) pruner.export_model(MODEL_FILE, MASK_FILE)
...@@ -345,7 +398,10 @@ class SpeedupTestCase(TestCase): ...@@ -345,7 +398,10 @@ class SpeedupTestCase(TestCase):
zero_bn_bias(speedup_model) zero_bn_bias(speedup_model)
data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device) data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
ms = ModelSpeedup(speedup_model, data, MASK_FILE) if speedup_cfg is None:
speedup_cfg = {}
ms = ModelSpeedup(speedup_model, data,
MASK_FILE, confidence=2, **speedup_cfg)
ms.speedup_model() ms.speedup_model()
speedup_model.eval() speedup_model.eval()
...@@ -355,12 +411,13 @@ class SpeedupTestCase(TestCase): ...@@ -355,12 +411,13 @@ class SpeedupTestCase(TestCase):
ori_sum = torch.sum(ori_out).item() ori_sum = torch.sum(ori_out).item()
speeded_sum = torch.sum(speeded_out).item() speeded_sum = torch.sum(speeded_out).item()
print('Sum of the output of %s (before speedup):' % print('Sum of the output of %s (before speedup):' %
model_name, ori_sum) model_name, ori_sum)
print('Sum of the output of %s (after speedup):' % print('Sum of the output of %s (after speedup):' %
model_name, speeded_sum) model_name, speeded_sum)
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \ assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD) (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
print("Collecting Garbage")
gc.collect(2)
def test_channel_prune(self): def test_channel_prune(self):
orig_net = resnet18(num_classes=10).to(device) orig_net = resnet18(num_classes=10).to(device)
...@@ -378,7 +435,7 @@ class SpeedupTestCase(TestCase): ...@@ -378,7 +435,7 @@ class SpeedupTestCase(TestCase):
net.eval() net.eval()
data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device) data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device)
ms = ModelSpeedup(net, data, MASK_FILE) ms = ModelSpeedup(net, data, MASK_FILE, confidence=2)
ms.speedup_model() ms.speedup_model()
ms.bound_model(data) ms.bound_model(data)
...@@ -391,11 +448,56 @@ class SpeedupTestCase(TestCase): ...@@ -391,11 +448,56 @@ class SpeedupTestCase(TestCase):
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \ assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD) (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
def test_speedup_tupleunpack(self):
"""This test is reported in issue3645"""
model = TupleUnpack_Model()
cfg_list = [{'op_types': ['Conv2d'], 'sparsity':0.5}]
dummy_input = torch.rand(2, 3, 224, 224)
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
model(dummy_input)
pruner.export_model(MODEL_FILE, MASK_FILE)
ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
ms.speedup_model()
def test_finegrained_speedup(self):
""" Test the speedup on the fine-grained sparsity"""
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.fc1 = nn.Linear(1024, 1024)
self.fc2 = nn.Linear(1024, 1024)
self.fc3 = nn.Linear(1024, 512)
self.fc4 = nn.Linear(512, 10)
def forward(self, x):
x = x.view(-1, 1024)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)
return x
model = MLP().to(device)
dummy_input = torch.rand(16, 1, 32, 32).to(device)
cfg_list = [{'op_types': ['Linear'], 'sparsity':0.99}]
pruner = LevelPruner(model, cfg_list)
pruner.compress()
print('Original Arch')
print(model)
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=4)
ms.speedup_model()
print("Fine-grained speeduped model")
print(model)
def tearDown(self): def tearDown(self):
if os.path.exists(MODEL_FILE): if os.path.exists(MODEL_FILE):
os.remove(MODEL_FILE) os.remove(MODEL_FILE)
if os.path.exists(MASK_FILE): if os.path.exists(MASK_FILE):
os.remove(MASK_FILE) os.remove(MASK_FILE)
# GC to release memory
gc.collect(2)
if __name__ == '__main__': if __name__ == '__main__':
......