Unverified Commit 403195f0 authored by Yuge Zhang, committed by GitHub

Merge branch 'master' into nn-meter

parents 99aa8226 a7278d2d
......@@ -5,7 +5,7 @@ import logging
import copy
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']
......@@ -22,11 +22,12 @@ class NaiveQuantizer(Quantizer):
self.layer_scale = {}
def validate_config(self, model, config_list):
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): ['weight'],
Optional('quant_bits'): Or(8, {'weight': 8}),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
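# Illustrative sketch (not part of this diff): with the new `exclude` key a
# layer can be explicitly skipped by the quantizer. `model` and the layer
# name 'fc' are assumptions for illustration only.
# config_list = [
#     {'quant_types': ['weight'], 'quant_bits': 8, 'op_types': ['Conv2d']},
#     {'exclude': True, 'op_names': ['fc']},  # leave this layer unquantized
# ]
# quantizer = NaiveQuantizer(model, config_list)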
......@@ -183,7 +184,7 @@ class QAT_Quantizer(Quantizer):
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32),
......@@ -191,7 +192,8 @@ class QAT_Quantizer(Quantizer):
})),
Optional('quant_start_step'): And(int, lambda n: n >= 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......@@ -386,13 +388,14 @@ class DoReFaQuantizer(Quantizer):
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32)
})),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......@@ -493,14 +496,15 @@ class BNNQuantizer(Quantizer):
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32),
Optional('output'): And(int, lambda n: 0 < n < 32),
})),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......
......@@ -71,7 +71,7 @@ class CurvefittingAssessor(Assessor):
else:
self.set_best_performance = True
self.completed_best_performance = self.trial_history[-1]
logger.info('Updated complted best performance, trial job id: %s', trial_job_id)
logger.info('Updated completed best performance, trial job id: %s', trial_job_id)
else:
logger.info('No need to update, trial job id: %s', trial_job_id)
......
......@@ -71,7 +71,11 @@ class TorchGraph:
def _trace(self, model, dummy_input):
training = model.training
model.eval()
self.trace = torch.jit.trace(model, dummy_input)
kw_args = {}
if torch.__version__ >= '1.6.0':
# only pytorch with version >= 1.6.0 has the strict option
kw_args['strict'] = False
self.trace = torch.jit.trace(model, dummy_input, **kw_args)
torch._C._jit_pass_inline(self.trace.graph)
model.train(training)
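# Hedged side note (not from this commit): the version check above compares
# strings lexicographically, so e.g. '1.10.0' < '1.6.0'. A more robust guard,
# assuming the `packaging` package is available, could look like:
#     from packaging.version import Version
#     if Version(torch.__version__.split('+')[0]) >= Version('1.6.0'):
#         kw_args['strict'] = False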
......@@ -247,6 +251,7 @@ class TorchModuleGraph(TorchGraph):
def __init__(self, model=None, dummy_input=None, traced_model=None):
super().__init__(model, dummy_input, traced_model)
self.global_count = 0
self.reused_module = set()
self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
self._extract_auxiliary_info()
......@@ -390,9 +395,12 @@ class TorchModuleGraph(TorchGraph):
outputs.append(output_name)
else:
outputs.append(output_name)
unique_outputs = list(set(outputs))
# remove the duplicated output names
unique_outputs.sort(key=outputs.index)
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=list(inputs), outputs=list(outputs))
node_group, inputs=list(inputs), outputs=unique_outputs)
return nodepy
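# A self-contained illustration of the order-preserving de-duplication used above:
#     outputs = ['t1', 't2', 't1', 't3']
#     unique_outputs = list(set(outputs))
#     unique_outputs.sort(key=outputs.index)  # sort by first occurrence
#     assert unique_outputs == ['t1', 't2', 't3']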
def _extract_cat_info(self, node_group, cpp_node):
......@@ -724,6 +732,8 @@ class TorchModuleGraph(TorchGraph):
unique_name = module_name
if use_count > 0:
unique_name = module_name + '.%d' % use_count
self.reused_module.add(unique_name)
self.reused_module.add(module_name)
node_group = self._expand_module_node(
node, module_name, unique_name, module_to_type[module_name],
node_cpps, input_to_node, output_to_node, 'module')
......
......@@ -12,7 +12,8 @@ from . import calibrator as calibrator
from . import trt_pycuda as common
from .backend import BaseModelSpeedup
# TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
TRT8 = 8
TRT7 = 7
TRT_LOGGER = trt.Logger()
logger = logging.getLogger(__name__)
......@@ -23,7 +24,7 @@ class CalibrateType:
MINMAX = trt.CalibrationAlgoType.MINMAX_CALIBRATION
Precision_Dict = {
8: trt.float32,
8: trt.int8,
16: trt.float16,
32: trt.float32
}
......@@ -120,22 +121,43 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
An ICudaEngine for executing inference on a built network
"""
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, \
trt.OnnxParser(network, TRT_LOGGER) as parser:
trt.OnnxParser(network, TRT_LOGGER) as parser, builder.create_builder_config() as trt_config:
# Note that max_batch_size should be set to 1 because of the implementation of allocate_buffer
trt_version = int(trt.__version__[0])
assert trt_version == TRT8 or trt_version == TRT7, "Version of TensorRT is too old, please \
update TensorRT to version >= 7.0"
if trt_version == TRT7:
logger.warning("TensorRT7 is deprecated and may be removed in the following release.")
builder.max_batch_size = 1
builder.max_workspace_size = common.GiB(4)
if trt_version == TRT8:
trt_config.max_workspace_size = common.GiB(4)
else:
builder.max_workspace_size = common.GiB(4)
if extra_layer_bit == 32 and config is None:
pass
elif extra_layer_bit == 16 and config is None:
builder.fp16_mode = True
if trt_version == TRT8:
trt_config.set_flag(trt.BuilderFlag.FP16)
else:
builder.fp16_mode = True
elif extra_layer_bit == 8 and config is None:
# entire model in 8bit mode
builder.int8_mode = True
if trt_version == TRT8:
trt_config.set_flag(trt.BuilderFlag.INT8)
else:
builder.int8_mode = True
else:
builder.int8_mode = True
builder.fp16_mode = True
builder.strict_type_constraints = strict_datatype
if trt_version == TRT8:
trt_config.set_flag(trt.BuilderFlag.INT8)
trt_config.set_flag(trt.BuilderFlag.FP16)
if strict_datatype:
trt_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
else:
builder.int8_mode = True
builder.fp16_mode = True
builder.strict_type_constraints = strict_datatype
valid_config(config)
......@@ -148,7 +170,10 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
return None
if calib is not None:
builder.int8_calibrator = calib
if trt_version == TRT8:
trt_config.int8_calibrator = calib
else:
builder.int8_calibrator = calib
# This design may not be correct if there is more than one output
for i in range(network.num_layers):
if config is None:
......@@ -196,7 +221,10 @@ def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=Fa
out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)
# Build engine and do int8 calibration.
engine = builder.build_cuda_engine(network)
if trt_version == TRT8:
engine = builder.build_engine(network, trt_config)
else:
engine = builder.build_cuda_engine(network)
return engine
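# Illustrative calls (sketch; only the arguments shown in this diff's signature
# are assumed, and the file name is hypothetical):
#     engine = build_engine('model.onnx', extra_layer_bit=16)  # whole model in fp16
#     engine = build_engine('model.onnx', extra_layer_bit=8)   # whole model in int8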
class ModelSpeedupTensorRT(BaseModelSpeedup):
......
......@@ -3,196 +3,394 @@
import logging
import torch
from .infer_shape import ModuleMasks
import torch.nn as nn
_logger = logging.getLogger(__name__)
replace_module = {
'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
'Conv2d': lambda module, mask: replace_conv2d(module, mask),
'ConvTranspose2d': lambda module, mask: replace_convtranspose2d(module, mask),
'MaxPool2d': lambda module, mask: no_replace(module, mask),
'AvgPool2d': lambda module, mask: no_replace(module, mask),
'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
'ReLU': lambda module, mask: no_replace(module, mask),
'ReLU6': lambda module, mask: no_replace(module, mask),
'Sigmoid': lambda module, mask: no_replace(module, mask),
'Linear': lambda module, mask: replace_linear(module, mask),
'Dropout': lambda module, mask: no_replace(module, mask),
'Dropout2d': lambda module, mask: no_replace(module, mask),
'Dropout3d': lambda module, mask: no_replace(module, mask)
'BatchNorm2d': lambda module, masks: replace_batchnorm2d(module, masks),
'BatchNorm1d': lambda module, masks: replace_batchnorm1d(module, masks),
'Conv2d': lambda module, masks: replace_conv2d(module, masks),
'Linear': lambda module, masks: replace_linear(module, masks),
'MaxPool2d': lambda module, masks: no_replace(module, masks),
'AvgPool2d': lambda module, masks: no_replace(module, masks),
'AdaptiveAvgPool2d': lambda module, masks: no_replace(module, masks),
'ReLU': lambda module, masks: no_replace(module, masks),
'ReLU6': lambda module, masks: no_replace(module, masks),
'LeakyReLU': lambda module, masks: no_replace(module, masks),
'ELU': lambda module, masks: no_replace(module, masks),
'Hardtanh': lambda module, masks: no_replace(module, masks),
'Hardsigmoid': lambda module, masks: no_replace(module, masks),
'LogSigmoid': lambda module, masks: no_replace(module, masks),
'PReLU': lambda module, masks: replace_prelu(module, masks),
'RReLU': lambda module, masks: no_replace(module, masks),
'SELU': lambda module, masks: no_replace(module, masks),
'CELU': lambda module, masks: no_replace(module, masks),
'GELU': lambda module, masks: no_replace(module, masks),
'Sigmoid': lambda module, masks: no_replace(module, masks),
'SiLU': lambda module, masks: no_replace(module, masks),
'Mish': lambda module, masks: no_replace(module, masks),
'Tanh': lambda module, masks: no_replace(module, masks),
'Softplus': lambda module, masks: no_replace(module, masks),
'Softshrink': lambda module, masks: no_replace(module, masks),
'Softmax': lambda module, masks: no_replace(module, masks),
'Tanhshrink': lambda module, masks: no_replace(module, masks),
'Dropout': lambda module, masks: no_replace(module, masks),
'Dropout2d': lambda module, masks: no_replace(module, masks),
'Dropout3d': lambda module, masks: no_replace(module, masks),
'Upsample': lambda module, masks: no_replace(module, masks),
'LayerNorm': lambda module, masks: replace_layernorm(module, masks),
'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks)
}
def no_replace(module, mask):
def convert_to_coarse_mask(t_mask, dim):
"""
Convert the mask tensor to the coarse-grained mask tensor.
Parameters
----------
t_mask: torch.Tensor
The tensor only has 1s and 0s, 0 indicates this value is masked
and 1 indicates the corresponding value is not masked.
dim: int
Try to reduce the mask tensor on this dimension.
Returns
-------
indexes: torch.Tensor
The indexes of the sparsity that can be structurally removed.
remained_indexes: torch.Tensor
The indexes of the values that need to be retained.
"""
assert isinstance(t_mask, torch.Tensor)
shape = list(t_mask.size())
n_dims = len(shape)
dim_list = list(range(n_dims))
# try to reduce the mask from the dim-th dimension
dim_list.remove(dim)
t_merged = torch.sum(t_mask, dim_list)
assert t_merged.size(0) == shape[dim]
all_pruned = t_merged == 0
need_remain = t_merged != 0
# return the indexes of the sparsity that can be removed
indexes = torch.nonzero(all_pruned, as_tuple=True)[0]
remained_indexes = torch.nonzero(need_remain, as_tuple=True)[0]
return indexes, remained_indexes
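# Worked example (illustrative): for a (N=2, C=3) mask whose middle channel is
# entirely zero, channel 1 can be structurally removed:
#     t_mask = torch.tensor([[1., 0., 1.],
#                            [1., 0., 1.]])
#     pruned, remained = convert_to_coarse_mask(t_mask, dim=1)
#     assert pruned.tolist() == [1] and remained.tolist() == [0, 2]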
def no_replace(module, masks):
"""
No need to replace
"""
_logger.debug("no need to replace")
return module
def replace_prelu(prelu, masks):
"""
Parameters
----------
prelu : torch.nn.PReLU
The PReLU module to be replaced
masks : tuple of masks
The input/output/weight masks of the target module
def replace_linear(linear, mask):
Returns
-------
torch.nn.PReLU
The new prelu module
"""
in_masks, output_mask, weight_mask = masks
assert len(in_masks) == 1
assert isinstance(output_mask, torch.Tensor)
in_mask = in_masks[0]
weight_mask = weight_mask['weight']
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
n_remained_in = weight_mask.size(0) - pruned_in.size(0)
n_remained_out = weight_mask.size(0) - pruned_out.size(0)
remained_in, remained_out = remained_in.to(
prelu.weight.device), remained_out.to(prelu.weight.device)
assert n_remained_in == n_remained_out
if n_remained_in == 0:
return torch.nn.Identity()
new_prelu = torch.nn.PReLU(n_remained_in)
new_prelu.weight.data = torch.index_select(prelu.weight.data, 0, remained_in)
return new_prelu
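# Minimal usage sketch (illustrative shapes; batch size 2 stands in for the
# `confidence` batch used during mask inference):
#     prelu = torch.nn.PReLU(4)
#     in_mask = torch.ones(2, 4, 8, 8); in_mask[:, 1] = 0
#     masks = ([in_mask], in_mask.clone(), {'weight': torch.ones(4)})
#     new_prelu = replace_prelu(prelu, masks)
#     assert new_prelu.num_parameters == 3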
def replace_linear(linear, masks):
"""
This function will replace the original linear module according to
the inferred masks. This function supports both fine-grained and
coarse-grained sparsity. In the fine-grained scenario, this function
will remove the whole columns/rows that happen to be totally covered by
the masks.
Parameters
----------
linear : torch.nn.Linear
The linear module to be replaced
mask : ModuleMasks
The masks of this module
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], output_m, {'weight': weight_m})
Returns
-------
torch.nn.Linear
The new linear module
"""
assert isinstance(mask, ModuleMasks)
assert mask.input_mask is not None
assert mask.output_mask is None
assert not mask.param_masks
index = mask.input_mask.mask_index[-1]
in_features = index.size()[0]
_logger.debug("replace linear with new in_features: %d", in_features)
new_linear = torch.nn.Linear(in_features=in_features,
out_features=linear.out_features,
bias=linear.bias is not None)
new_linear.to(linear.weight.device)
new_linear.weight.data = torch.index_select(
linear.weight.data, -1, index.to(linear.weight.device))
in_masks, output_mask, weight_mask = masks
assert isinstance(linear, nn.Linear)
assert len(in_masks) == 1
assert isinstance(output_mask, torch.Tensor)
in_mask = in_masks[0]
weight_mask = weight_mask['weight']
# the mask is reduced along dim 1 (the feature dimension)
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
n_remained_in = weight_mask.size(1) - pruned_in.size(0)
n_remained_out = weight_mask.size(0) - pruned_out.size(0)
remained_in, remained_out = remained_in.to(
linear.weight.device), remained_out.to(linear.weight.device)
_logger.info("replace linear with new in_features: %d, out_features: %d",
n_remained_in, n_remained_out)
need_bias = False
if linear.bias is not None:
new_linear.bias.data.copy_(linear.bias.data)
need_bias = True
new_linear = torch.nn.Linear(in_features=n_remained_in,
out_features=n_remained_out,
bias=need_bias)
new_linear.to(linear.weight.device)
# Copy the remained weight from the original module
with torch.no_grad():
tmp_weight_data = torch.index_select(
linear.weight.data, 0, remained_out)
new_linear.weight.data = torch.index_select(
tmp_weight_data, 1, remained_in)
if linear.bias is not None:
new_linear.bias.data = torch.index_select(
linear.bias.data, 0, remained_out)
return new_linear
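# Minimal usage sketch (illustrative; batch size 2 plays the role of the
# `confidence` batch):
#     linear = nn.Linear(4, 3)
#     in_mask = torch.ones(2, 4); in_mask[:, 1] = 0    # drop input feature 1
#     out_mask = torch.ones(2, 3); out_mask[:, 2] = 0  # drop output feature 2
#     new_linear = replace_linear(linear, ([in_mask], out_mask,
#                                          {'weight': torch.ones(3, 4)}))
#     assert new_linear.in_features == 3 and new_linear.out_features == 2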
def replace_batchnorm2d(norm, mask):
def replace_batchnorm1d(norm, masks):
"""
Parameters
----------
norm : torch.nn.BatchNorm1d
The batchnorm module to be replaced
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], output_m, {'weight': weight_m})
Returns
-------
torch.nn.BatchNorm1d
The new batchnorm module
"""
in_masks, output_mask, _ = masks
assert isinstance(norm, nn.BatchNorm1d)
in_mask = in_masks[0]
# N, C or N, C, L
_, remained_in = convert_to_coarse_mask(in_mask, 1)
_, remained_out = convert_to_coarse_mask(output_mask, 1)
assert remained_in.size(0) == remained_out.size(0)
num_features = remained_in.size(0)
_logger.info("replace batchnorm1d with num_features: %d", num_features)
new_norm = torch.nn.BatchNorm1d(num_features=num_features,
eps=norm.eps,
momentum=norm.momentum,
affine=norm.affine,
track_running_stats=norm.track_running_stats)
# assign weights
new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
new_norm.running_mean.data = torch.index_select(
norm.running_mean.data, 0, remained_in)
new_norm.running_var.data = torch.index_select(
norm.running_var.data, 0, remained_in)
return new_norm
def replace_batchnorm2d(norm, masks):
"""
Parameters
----------
norm : torch.nn.BatchNorm2d
The batchnorm module to be replaced
mask : ModuleMasks
The masks of this module
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], output_m, {'weight': weight_m})
Returns
-------
torch.nn.BatchNorm2d
The new batchnorm module
"""
assert isinstance(mask, ModuleMasks)
assert 'weight' in mask.param_masks and 'bias' in mask.param_masks
index = mask.param_masks['weight'].mask_index[0]
num_features = index.size()[0]
_logger.debug("replace batchnorm2d with num_features: %d", num_features)
in_masks, output_mask, _ = masks
assert isinstance(norm, nn.BatchNorm2d)
in_mask = in_masks[0]
# N, C, H, W
_, remained_in = convert_to_coarse_mask(in_mask, 1)
_, remained_out = convert_to_coarse_mask(output_mask, 1)
assert remained_in.size(0) == remained_out.size(0)
num_features = remained_in.size(0)
_logger.info("replace batchnorm2d with num_features: %d", num_features)
new_norm = torch.nn.BatchNorm2d(num_features=num_features,
eps=norm.eps,
momentum=norm.momentum,
affine=norm.affine,
track_running_stats=norm.track_running_stats)
# assign weights
new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, index)
if norm.track_running_stats:
new_norm.running_mean.data = torch.index_select(
norm.running_mean.data, 0, index)
new_norm.running_var.data = torch.index_select(
norm.running_var.data, 0, index)
new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
new_norm.running_mean.data = torch.index_select(
norm.running_mean.data, 0, remained_in)
new_norm.running_var.data = torch.index_select(
norm.running_var.data, 0, remained_in)
return new_norm
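# Minimal usage sketch (illustrative):
#     bn = nn.BatchNorm2d(4)
#     mask = torch.ones(2, 4, 8, 8); mask[:, 0] = 0  # drop channel 0
#     new_bn = replace_batchnorm2d(bn, ([mask], mask.clone(), None))
#     assert new_bn.num_features == 3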
def replace_conv2d(conv, mask):
def replace_conv2d(conv, masks):
"""
Replace the original conv with a new one according to the inferred
masks. The function supports both fine-grained and coarse-grained
sparsity. In the fine-grained scenario, this replace function will remove
the filters that happen to be totally covered by the fine-grained sparsity.
Parameters
----------
conv : torch.nn.Conv2d
The conv2d module to be replaced
mask : ModuleMasks
The masks of this module
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], output_m, {'weight': weight_m})
Returns
-------
torch.nn.Conv2d
The new conv2d module
"""
assert isinstance(mask, ModuleMasks)
if mask.input_mask is None:
in_channels = conv.in_channels
else:
in_channels_index = mask.input_mask.mask_index[1]
in_channels = in_channels_index.size()[0]
if mask.output_mask is None:
out_channels = conv.out_channels
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size()[0]
groups = conv.groups
if conv.in_channels == conv.out_channels == conv.groups:
# remove groups for depthwise layers
assert in_channels == out_channels
groups = in_channels
_logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d",
mask.module_name, in_channels, out_channels)
new_conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
in_masks, output_mask, weight_masks = masks
assert isinstance(conv, nn.Conv2d)
# the conv layer should only have one input tensor
assert len(in_masks) == 1
in_mask = in_masks[0]
weight_mask = weight_masks['weight']
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
n_remained_in = weight_mask.size(1) * conv.groups - pruned_in.size(0)
n_remained_out = weight_mask.size(0) - pruned_out.size(0)
assert n_remained_in == remained_in.size(0)
assert n_remained_out == remained_out.size(0)
k_size1, k_size2 = conv.kernel_size
# Note: We should resolve the group dependency of the conv layers before
# running into this function.
# check if the mask tensor meets the group dependency and calculate the
# new number of the groups after pruning
# the original step size of the input channel for each group
ori_inchannel_step = int(conv.in_channels/conv.groups)
# the original step size of the output channel for each group
ori_outchannel_step = int(conv.out_channels/conv.groups)
# calculate the new_in_channel_step and new_outchannel_step first
new_inchannel_step = new_outchannel_step = None
for groupid in range(conv.groups):
in_start = groupid * ori_inchannel_step
in_end = in_start + ori_inchannel_step
out_start = groupid * ori_outchannel_step
out_end = out_start + ori_outchannel_step
current_input_index = list(
filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
current_output_index = list(
filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
# remap the global index to the group index
if len(current_input_index) == 0:
# if the whole group is pruned
continue
else:
new_inchannel_step = len(current_input_index)
new_outchannel_step = len(current_output_index)
break
tmp_weight = torch.ones(
n_remained_out, new_inchannel_step, k_size1, k_size2)
tmp_weight = tmp_weight.to(conv.weight.device)
assert n_remained_in % new_inchannel_step == 0
assert n_remained_out % new_outchannel_step == 0
new_groups = 0
for groupid in range(conv.groups):
in_start = groupid * ori_inchannel_step
in_end = in_start + ori_inchannel_step
out_start = groupid * ori_outchannel_step
out_end = out_start + ori_outchannel_step
current_input_index = list(
filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
current_output_index = list(
filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
# remap the global index to the group index
current_input_index = [x-in_start for x in current_input_index]
if len(current_input_index) == 0:
# if the whole group is pruned
assert len(current_output_index) == 0
continue
# check that the number of remaining channels in each group is the same
assert len(current_input_index) == new_inchannel_step
assert len(current_output_index) == new_outchannel_step
# copy the weight into tmp_weight
new_out_start = new_outchannel_step * new_groups
new_out_end = new_out_start + new_outchannel_step
tmp_weight[new_out_start:new_out_end] = torch.index_select(
conv.weight[current_output_index], 1, torch.as_tensor(current_input_index, dtype=torch.long).to(conv.weight.device))
new_groups += 1
_logger.debug("replace conv2d with in_channels: %d, out_channels: %d",
n_remained_in, n_remained_out)
# need_bias indicates whether the new conv layer needs a bias; if the
# original conv doesn't have a bias and there is no constant that needs
# to be folded into the bias, then need_bias is False.
need_bias = conv.bias is not None
new_conv = torch.nn.Conv2d(in_channels=n_remained_in,
out_channels=n_remained_out,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=groups,
bias=conv.bias is not None,
groups=new_groups,
bias=need_bias,
padding_mode=conv.padding_mode)
new_conv.to(conv.weight.device)
tmp_weight_data = tmp_bias_data = None
if mask.output_mask is not None:
tmp_weight_data = torch.index_select(
conv.weight.data, 0, out_channels_index)
if conv.bias is not None:
tmp_bias_data = torch.index_select(
conv.bias.data, 0, out_channels_index)
else:
tmp_weight_data = conv.weight.data
# For the convolutional layers that have more than one group
# we need to copy the weight group by group, because the input
# channel is also divided into several groups and each group
# filter may have different input channel indexes.
input_step = int(conv.in_channels / conv.groups)
in_channels_group = int(in_channels / groups)
filter_step = int(out_channels / groups)
if mask.input_mask is not None and not (in_channels == out_channels == groups):
for groupid in range(conv.groups):
start = groupid * input_step
end = (groupid + 1) * input_step
current_input_index = list(
filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
if not current_input_index:
# there is no kept channel in current group
# TODO bug here, the groups is directly get from conv.groups, if the whole group is removed,
# then the number of groups in the new_conv also need to change
raise Exception(
" Donnot support removing the whole group filter except in the depth-wise conv temporarily")
# shift the global index into the group index
current_input_index = [x-start for x in current_input_index]
# if the groups is larger than 1, the input channels of each
# group should be pruned evenly.
assert len(current_input_index) == in_channels_group, \
'Input channels of each group are not pruned evenly'
current_input_index = torch.tensor(current_input_index).to(tmp_weight_data.device) # pylint: disable=not-callable
f_start = groupid * filter_step
f_end = (groupid + 1) * filter_step
new_conv.weight.data[f_start:f_end] = torch.index_select(
tmp_weight_data[f_start:f_end], 1, current_input_index)
else:
new_conv.weight.data.copy_(tmp_weight_data)
new_conv.weight.copy_(tmp_weight)
# copy the bias data
if conv.bias is not None:
new_conv.bias.data.copy_(
conv.bias.data if tmp_bias_data is None else tmp_bias_data)
new_conv.bias.data.copy_(torch.index_select(
conv.bias.data, 0, remained_out))
return new_conv
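# Minimal usage sketch (illustrative): pruning one channel of a depthwise conv
# (in_channels == out_channels == groups), where the group count shrinks too:
#     conv = nn.Conv2d(3, 3, 3, groups=3)
#     in_mask = torch.ones(2, 3, 8, 8); in_mask[:, 1] = 0
#     out_mask = torch.ones(2, 3, 6, 6); out_mask[:, 1] = 0
#     new_conv = replace_conv2d(conv, ([in_mask], out_mask,
#                                      {'weight': torch.ones(3, 1, 3, 3)}))
#     assert new_conv.in_channels == new_conv.out_channels == new_conv.groups == 2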
def replace_convtranspose2d(convtrans, mask):
def replace_convtranspose2d(convtrans, masks):
"""
We need another replace function for
convtranspose2d, because the layout of
......@@ -202,81 +400,120 @@ def replace_convtranspose2d(convtrans, mask):
----------
convtrans : torch.nn.ConvTranspose2d
The convtranspose2d module to be replaced
mask : ModuleMasks
The masks of this module
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], output_m, {'weight': weight_m})
Returns
-------
torch.nn.ConvTranspose2d
The new convtranspose2d module
"""
assert isinstance(mask, ModuleMasks)
in_masks, output_mask, weight_masks = masks
assert isinstance(convtrans, torch.nn.ConvTranspose2d)
if mask.input_mask is None:
in_channels = convtrans.in_channels
else:
in_channels_index = mask.input_mask.mask_index[1]
in_channels = in_channels_index.size(0)
if mask.output_mask is None:
out_channels = convtrans.out_channels
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size(0)
groups = convtrans.groups
# check if can remove the whole group of filters
if convtrans.in_channels == convtrans.out_channels == convtrans.groups:
# remove groups for depthwise layers
# this needs the group dependency to be fixed before the speedup
assert in_channels == out_channels
groups = in_channels
_logger.debug('Replace convtranspose2d %s with in_channels:%d out_channels:%d',
mask.module_name, in_channels, out_channels)
new_convtrans = torch.nn.ConvTranspose2d(in_channels=in_channels,
out_channels=out_channels,
assert len(in_masks) == 1
in_mask = in_masks[0]
weight_mask = weight_masks['weight']
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
# ConvTranspose2d has the weight shape of [N_in, N_out/groups, k1, k2]
n_remained_in = weight_mask.size(0) - pruned_in.size(0)
n_remained_out = weight_mask.size(
1) * convtrans.groups - pruned_out.size(0)
assert n_remained_in == remained_in.size(0)
assert n_remained_out == remained_out.size(0)
k_size1, k_size2 = convtrans.kernel_size
# Note: we should resolve the group dependency of the convtrans layers before
# running into this function
ori_inchannel_step = int(convtrans.in_channels/convtrans.groups)
ori_outchannel_step = int(convtrans.out_channels/convtrans.groups)
new_inchannel_step = new_outchannel_step = None
for groupid in range(convtrans.groups):
in_start = groupid * ori_inchannel_step
in_end = in_start + ori_inchannel_step
out_start = groupid * ori_outchannel_step
out_end = out_start + ori_outchannel_step
current_input_index = list(
filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
current_output_index = list(
filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
if len(current_input_index) == 0:
# if the whole group is pruned
continue
else:
new_inchannel_step = len(current_input_index)
new_outchannel_step = len(current_output_index)
break
tmp_weight = torch.ones(
n_remained_in, new_outchannel_step, k_size1, k_size2)
tmp_weight = tmp_weight.to(convtrans.weight.device)
assert n_remained_in % new_inchannel_step == 0
assert n_remained_out % new_outchannel_step == 0
new_groups = 0
for groupid in range(convtrans.groups):
# copy the weights of this group
in_start = groupid * ori_inchannel_step
in_end = in_start + ori_inchannel_step
out_start = groupid * ori_outchannel_step
out_end = out_start + ori_outchannel_step
current_input_index = list(
filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
current_output_index = list(
filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
# remap the global index to the group index
# in the convtranspose layer, the groups are on
# the output channel dimension
current_output_index = [x-out_start for x in current_output_index]
if len(current_input_index) == 0:
# if the whole group is pruned
assert len(current_output_index) == 0
continue
# check that the number of remaining channels in each group is the same
assert len(current_input_index) == new_inchannel_step
assert len(current_output_index) == new_outchannel_step
# copy the weight into tmp_weight
new_in_start = new_inchannel_step * new_groups
new_in_end = new_in_start + new_inchannel_step
tmp_weight[new_in_start:new_in_end] = torch.index_select(
convtrans.weight[current_input_index], 1, torch.as_tensor(current_output_index, dtype=torch.long).to(convtrans.weight.device))
new_groups += 1
_logger.debug('Replace convtranspose2d with in_channels:%d out_channels:%d',
n_remained_in, n_remained_out)
new_convtrans = torch.nn.ConvTranspose2d(in_channels=n_remained_in,
out_channels=n_remained_out,
kernel_size=convtrans.kernel_size,
stride=convtrans.stride,
padding=convtrans.padding,
dilation=convtrans.dilation,
groups=groups,
groups=new_groups,
bias=convtrans.bias is not None,
padding_mode=convtrans.padding_mode)
new_convtrans.to(convtrans.weight.device)
tmp_weight_data = None
if mask.input_mask is not None:
# in convtranspose2d we need to select the input channel first
tmp_weight_data = torch.index_select(
convtrans.weight.data, 0, in_channels_index)
else:
tmp_weight_data = convtrans.weight.data
# we need to handle the output channel group by group like the conv layer
out_step = int(convtrans.out_channels / convtrans.groups)
out_channel_group = int(out_channels/groups)
new_in_per_group = int(in_channels/groups)
if mask.output_mask is not None and not(in_channels == out_channels == groups):
for groupid in range(convtrans.groups):
start = groupid * out_step
end = (groupid + 1) * out_step
current_output_index = list(
filter(lambda x: start <= x and x < end, out_channels_index.tolist()))
# we need to shift the index into the group-wise
current_output_index = [x-start for x in current_output_index]
if not current_output_index:
# No kept channel in the current group
raise Exception(
" Donnot support removing the whole group filter except in the depth-wise conv temporarily")
assert len(current_output_index) == out_channel_group, \
'Output channel of each group should be the same after pruning'
current_output_index = torch.tensor(current_output_index).to(tmp_weight_data.device) # pylint: disable=not-callable
new_start = groupid * new_in_per_group
new_end = (groupid + 1) * new_in_per_group
new_convtrans.weight.data[new_start:new_end] = torch.index_select(
tmp_weight_data[new_start:new_end], 1, current_output_index)
else:
new_convtrans.weight.data.copy_(tmp_weight_data)
new_convtrans.weight.copy_(tmp_weight)
if convtrans.bias is not None:
if mask.output_mask is not None:
if output_mask is not None:
new_convtrans.bias.data[:] = torch.index_select(
convtrans.bias.data, 0, out_channels_index)
convtrans.bias.data, 0, remained_out)
else:
new_convtrans.bias.data.copy_(convtrans.bias.data)
return new_convtrans
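# Minimal usage sketch (illustrative; ConvTranspose2d weight is laid out as
# [N_in, N_out/groups, k1, k2]):
#     convt = nn.ConvTranspose2d(2, 4, 2)
#     in_mask = torch.ones(2, 2, 4, 4); in_mask[:, 1] = 0
#     out_mask = torch.ones(2, 4, 5, 5); out_mask[:, 3] = 0
#     new_convt = replace_convtranspose2d(convt, ([in_mask], out_mask,
#                                                 {'weight': torch.ones(2, 4, 2, 2)}))
#     assert new_convt.in_channels == 1 and new_convt.out_channels == 3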
def replace_layernorm(layernorm, masks):
in_masks, _, _ = masks
assert isinstance(layernorm, nn.LayerNorm)
assert len(in_masks) == 1
in_mask = in_masks[0]
dim_n = len(in_mask.size())
new_shape = []
for i in range(1, dim_n):
sum_dims = list(range(0, dim_n))
sum_dims.remove(i)
reduced = torch.sum(in_mask, sum_dims)
n_remained = torch.sum(reduced > 0)
new_shape.append(int(n_remained))  # cast to a plain int for nn.LayerNorm
return nn.LayerNorm(tuple(new_shape), layernorm.eps, layernorm.elementwise_affine)
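# Minimal usage sketch (illustrative): only the normalized shape is rebuilt,
# the affine parameters are freshly initialized:
#     ln = nn.LayerNorm(4)
#     in_mask = torch.ones(2, 4); in_mask[:, 3] = 0
#     new_ln = replace_layernorm(ln, ([in_mask], None, None))
#     assert new_ln.normalized_shape == (3,)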
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import queue
import logging
import copy
import torch
import torch.nn as nn
from nni.common.graph_utils import build_module_graph
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.pytorch.utils.utils import get_module_by_name
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape, set_conv_prune_dim
from .infer_mask import AutoMaskInference
from .jit_translate import jit_to_python_function
from ..utils import rand_like_with_shape
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class ModelSpeedup:
"""
This class is to speedup the model with provided weight mask
This class is to speedup the model with provided weight mask.
"""
def __init__(self, model, dummy_input, masks_file, map_location=None):
def __init__(self, model, dummy_input, masks_file, map_location=None,
batch_dim=0, confidence=8):
"""
Parameters
----------
model : pytorch model
The model user wants to speed up
dummy_input : pytorch tensor
The dummy input for ```jit.trace```, users should put it on right device before pass in
dummy_input : pytorch tensor, tuple of tensor, list of tensor
Note: The first dimension of the dummy_input should be the batchsize.
The dummy input for ```jit.trace```, users should put it on the right
device.
masks_file : str
The path of user provided mask file
map_location : str
the device on which masks are placed, same to map_location in ```torch.load```
batch_dim : int
the index of batch dimension in the dummy_input
confidence : int
The confidence coefficient of the sparsity inference. This value is
actually used as the batch size of the generated dummy_input.
"""
from nni.common.graph_utils import build_module_graph
assert confidence > 1
# The auto inference will change the values of the parameters in the model
# so we need to make a copy before the mask inference
self.ori_state_dict = copy.deepcopy(model.state_dict())
self.bound_model = model
self.masks = torch.load(masks_file, map_location)
self.inferred_masks = dict() # key: module_name, value: ModuleMasks
self.dummy_input = dummy_input
self.torch_graph = build_module_graph(model, dummy_input)
self.inferred_masks = dict() # key: module_name, value: ModuleMasks
self.batch_dim = batch_dim
self.dummy_input, self.device = self._random_model_input(dummy_input, confidence, batch_dim)
self.torch_graph = build_module_graph(model, self.dummy_input)
# dict object to save the auto inferences objects of the submodules
self.auto_inferences = {}
# the index dict to find the corresponding torch._C.Value object
# according to the debug name
# we need the dummy_input to infer the mask automatically, so we save
# the indexes from the tensor's debugname to the torch._C.Value object.
self.debugname_to_value = {}
# load the mask tensor to the same device with the dummy_input
# self.masks saves the mask tensors pruned by the user and the inferred
# masks of the other modules
self.masks = torch.load(
masks_file, map_location if map_location is not None else str(self.device))
def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None, out_shape=None):
self.constant = {}
# self.internal_result save the internal output of the submodules
self.internal_result = {}
def _random_model_input(self, dummy_input, confidence, batch_dim):
"""
Infer input shape / output shape based on the module's weight mask / input shape / output shape.
Generate the new random dummy input according to the original dummy_input,
confidence and batch_dim.
For a module:
Infer its input and output shape from its weight mask
Infer its output shape from its input shape
Infer its input shape from its output shape
Parameters
----------
dummy_input: Tensor or list/dict of Tensors
The dummy_input given by the user.
confidence: int
The new batch size of the generated dummy_input.
batch_dim: int
The index of the batch dimension.
Returns
-------
new_dummy_input: Tensor or list/dict of Tensors
The generated dummy_input for mask inference.
device: torch.device
The device of the generated dummy_inputs
"""
input_errmsg = 'Only support the tensor, list/tuple/dict of tensors as input'
# Some models may use a list of tensors as input, for example transformers
new_dummy_input, device = None, None
if isinstance(dummy_input, torch.Tensor):
input_shape = list(dummy_input.size())
# set the batchsize to the confidence ratio
input_shape[batch_dim] = confidence
new_dummy_input = rand_like_with_shape(input_shape, dummy_input)
device = dummy_input.device
elif isinstance(dummy_input, (tuple, list)):
# else if the dummy input is list/tuple
new_dummy_input = []
old_batchsize = dummy_input[0].size(0)
device = dummy_input[0].device
for _, t_input in enumerate(dummy_input):
assert isinstance(t_input, torch.Tensor), input_errmsg
assert t_input.size(0) == old_batchsize, 'The first dimension should be batchsize\
and the batchsize of all inputs should be the same!'
input_shape = list(t_input.size())
input_shape[batch_dim] = confidence
new_dummy_input.append(
rand_like_with_shape(input_shape, t_input))
elif isinstance(dummy_input, dict):
new_dummy_input = {}
tmp_key = list(dummy_input.keys())[0]
old_batchsize = dummy_input[tmp_key].size(0)
device = dummy_input[tmp_key].device
for in_name, t_input in dummy_input.items():
assert isinstance(t_input, torch.Tensor), input_errmsg
assert old_batchsize == t_input.size(0), 'The first dimension should be batchsize\
and the batchsize of all inputs should be the same!'
input_shape = list(t_input.size())
input_shape[batch_dim] = confidence
new_dummy_input[in_name] = rand_like_with_shape(
input_shape, t_input)
else:
raise TypeError(input_errmsg)
return new_dummy_input, device
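# Illustrative behaviour (assuming rand_like_with_shape keeps dtype/device):
#     dummy = torch.rand(1, 3, 224, 224)
#     new_dummy, device = self._random_model_input(dummy, confidence=8, batch_dim=0)
#     assert list(new_dummy.size()) == [8, 3, 224, 224]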
If its input shape is changed, continue inferring its predecessors
If its output shape is changed, continue inferring its successors
def _prepare_dummy_input(self, node):
"""
Prepare the dummy_input for the auto mask inference.
Parameters
----------
module_name : str
The name of the node
last_module : str
The name of last visited node
mask : tensor of mask or ModuleMasks
Mask of the weights in this node (i.e., module)
in_shape : ModuleMasks
Input shape of this node
out_shape : ModuleMasks
Output shape of this node
"""
input_cmask = output_cmask = None
if module_name in self.inferred_masks:
module_masks = self.inferred_masks[module_name]
node: NodePyGroup
Returns
-------
dummy_input: list
List of tensors that will be used as input for the target node.
debugnames: list of strs
Debugnames of the dummy_inputs.
"""
_logger.debug('Prepare auto mask inference for node: %s',
node.unique_name)
# prepare the inputs and outputs mask for this node,
# if there is already a mask in self.masks, then use
# the original mask tensor, else create a new one.
inputs_name = node.inputs
# build the dummy_input and in_masks for the target node
dummy_input = []
debugnames = []
for _input in inputs_name:
if _input not in self.internal_result:
# if the input debug name is not in self.internal_result,
# then this node isn't an output tensor of any predecessor
# nodes. This node is an attribute of the submodule, such as
# weight or bias, etc. We will skip these tensors.
# If we don't want this specific judgement here, we can merge
# the `prim::GetAttr` node of the weight/bias tensor into the key
# node, such as `conv`.
# This is caused by the `meage_module_node` function in
# _graph_utils.py, because it doesn't merge the prim::GetAttr
# node into the key node. In the current version of _graph_utils.py,
# we only merge the nodes that have the same scope name; however,
# the scope name of the corresponding prim::GetAttr node of the
# `weight` tensor is None.
continue
# The detach operation here is for in-place operations. We cannot
# directly call backward on the output tensor of an in-place operator.
dummy_input.append(self.internal_result[_input].detach())
debugnames.append(_input)
return dummy_input, debugnames
def update_direct_sparsity(self, node):
"""
Update the direct sparsity for the target node. Here the direct sparsity
means the sparsity in the output tensor that is caused by the sparsity
in the input tensors/weight tensors.
"""
# this name is consistent with the name returned by named_modules()
module_name = node.name
_logger.info('Update mask for %s', module_name)
unique_name = node.unique_name
dummy_input, input_debugname = self._prepare_dummy_input(node)
# get the input mask from self.masks
# Note: the input masks of the successor nodes are
# already created by the predecessor node
in_masks = [self.masks[debugname] for debugname in input_debugname]
in_constants = [self.constant[debugname]
for debugname in input_debugname]
if node.type == 'func':
# we cannot get the runnable function directly from the jit traced
# graph, so we translate it back to a python function. Note: the function
# is applicable to both cpu/gpu devices, and the output tensors will be on
# the same device as the input tensors
func = jit_to_python_function(node, self)
if func is None:
# no need to infer the sparsity for this node
self.auto_inferences[unique_name] = None
return
# function doesn't have weights
_auto_infer = AutoMaskInference(
func, dummy_input, in_masks, in_constants=in_constants, batch_dim=self.batch_dim)
else:
_, m = get_module_by_name(self.bound_model, module_name)
module_masks = ModuleMasks(module_name, m)
self.inferred_masks[module_name] = module_masks
m_type = self.torch_graph.name_to_node[module_name].op_type
_logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
if mask is not None:
_logger.debug("mask is not None")
if not m_type in infer_from_mask:
raise RuntimeError(
"Has not supported infering input/output shape from mask for module/function: `{}`, {}"
.format(m_type, module_name))
if m_type in ['Linear']:
input_cmask, output_cmask = infer_from_mask[m_type](
module_masks, mask, self.torch_graph.name_to_node[module_name].auxiliary
)
else:
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if in_shape is not None:
_logger.debug("in_shape is not None")
if not m_type in infer_from_inshape:
raise RuntimeError(
"Has not supported infering output shape from input shape for module/function: `{}`, {}"
.format(m_type, module_name))
if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
output_cmask = infer_from_inshape[m_type](module_masks,
in_shape,
self.torch_graph.name_to_node[module_name].auxiliary)
elif m_type in ['aten::cat']:
# To calculate the mask for the concat operation, we need the output
# shape, cat dimension, and the order of the input parameters.
output_cmask = infer_from_inshape[m_type](module_masks,
in_shape,
self.torch_graph.name_to_node[module_name].auxiliary,
last_module)
else:
output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
if out_shape is not None:
_logger.debug("out_shape is not None")
if not m_type in infer_from_outshape:
raise RuntimeError(
"Has not supported infering input shape from output shape for module/function: `{}`, {}"
.format(m_type, module_name))
if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape, self.torch_graph.name_to_node[module_name].auxiliary)
else:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
weight_mask = None
if module_name in self.masks:
weight_mask = self.masks[module_name]
_, module = get_module_by_name(self.bound_model, module_name)
_auto_infer = AutoMaskInference(
module, dummy_input, in_masks, weight_mask, in_constants=in_constants,
state_dict=copy.deepcopy(module.state_dict()), batch_dim=self.batch_dim)
self.auto_inferences[unique_name] = _auto_infer
_auto_infer.name = node.unique_name
_auto_infer.update_direct_sparsity()
# also save the input debug names into the auto_infer
_auto_infer.input_debugname = input_debugname
# update the mask tensor and the internal output of the submodules
# after manually unpacking the tuple/list of tensors, the number of outputs
# of each node should always be one (except for the TupleUnpack node at the
# end of the whole model)
assert len(
node.outputs) == 1, 'The number of outputs should be one after the tuple is unpacked manually'
out_debugname = node.outputs[0]
# update the output mask into self.masks
self.masks[out_debugname] = _auto_infer.output_mask
self.constant[out_debugname] = _auto_infer.out_constant
# update the output result into self.internal_result, so that
# the successor nodes can take these output tensors as inputs.
self.internal_result[out_debugname] = _auto_infer.output
# update the parameter mask of the node
if input_cmask:
predecessors = self.torch_graph.find_predecessors(module_name)
for _module_name in predecessors:
self.infer_module_mask(_module_name, module_name, out_shape=input_cmask)
if output_cmask:
successors = self.torch_graph.find_successors(module_name)
for _module_name in successors:
self.infer_module_mask(_module_name, module_name, in_shape=output_cmask)
self.masks[module_name] = _auto_infer.weight_mask
def update_indirect_sparsity(self, node):
"""
This function updates the indirect sparsity. To explain what indirect
sparsity is: suppose there are two tensors TA and TB, and we perform the
calculation TC = TA x TB, in which TC is also a tensor. Once some values
in TA are masked to zeros, the corresponding positions in TB are also
potential sparsity, because they have no effect on the final output (the
gradient of these positions in TB is always 0). This function finds the
potential sparsity caused by other sparsity (we call it indirect sparsity
here). Basically, we find this potential sparsity through the gradient.
Parameters
----------
node : NodePyGroup
The target node to update the indirect sparsity
"""
module_name = node.name
_logger.info('Update indirect sparsity for %s', module_name)
unique_name = node.unique_name
if unique_name in self.auto_inferences and self.auto_inferences[unique_name] is not None:
# if the auto inference object is already in self.auto_inferences, then
# directly update the previous one
# self.auto_inferences[unique_name].update()
_logger.info(
'Update the indirect sparsity for the %s', unique_name)
auto_infer = self.auto_inferences[unique_name]
auto_infer.update_indirect_sparsity()
# pass the gradient to the predecessor nodes
for in_id, tin in enumerate(auto_infer.dummy_input):
debug_name = auto_infer.input_debugname[in_id]
last_output = self.internal_result[debug_name]
# if isinstance(last_output, torch.Tensor):
# TODO what if last output is tuple/list of tensor
if last_output.grad is not None and tin.grad is not None:
last_output.grad.data += tin.grad.data
else:
last_output.grad = tin.grad
else:
_logger.warning('Note: %s does not have corresponding mask inference object', node.name)
def _vnode_to_value(self, c_node):
"""
Translate the torch._C.Value node into values/tensors.
"""
errmsg = "Only support the torch._C.Value type"
assert isinstance(c_node, torch._C.Value), errmsg
if isinstance(c_node.type(), torch._C.TensorType):
shape = tuple(c_node.type().sizes())
dtype = c_node.type().scalarType()
# TODO should use a more general way to get the input
if dtype.startswith('Float') or dtype.startswith('Double'):
return torch.rand(shape).to(self.device)
else:
# This small range is due to `ReLU6`; we will add
# manually specified mask inference rules for several
# ops in the future, so that we can remove the constraint.
return torch.randint(0, 10, shape, device=self.device)
else:
value = c_node.toIValue()
# TODO support more kinds of value node
errmsg = "Doesn't support convert %s to values", str(c_node.type())
# currently only support the tensors and constant values
assert value is not None, errmsg
return value
def infer_modules_masks(self):
"""
Do shape inference of involved modules, including the shape of weights, inputs, output
"""
for module_name, mask in self.masks.items():
_logger.debug('Start mask inference from %s', module_name)
if module_name not in self.torch_graph.name_to_node:
# this module is not traced in the torch_graph,
# jit.trace only correctly records functions and
# modules which are not data dependent (e.g., do
# not have conditionals on data in tensors)
# so, if a node is not traced, we just skip it.
_logger.warning('%s has mask, but not found in the traced graph, just skip it.', module_name)
continue
self.infer_module_mask(module_name, None, mask=mask)
Infer the masks for all layers in the module. This function can be divided into
two steps: first, forward inference of the masks; second, backward inference
of the masks. We keep repeating these two steps until the masks of the model
don't change.
"""
# unpack the tensor tuple/list before the mask inference
self.torch_graph.unpack_manually()
# find the input/output tensors of the whole graph
graph_input = []
graph_output = []
for name, nodeio in self.torch_graph.nodes_py.nodes_io.items():
if nodeio.input_or_output == 'input':
graph_input.append((name, nodeio))
# also put the graph input tensor into the internal_result
# TODO if we can find the corresponding relation between the value node
# and the dummy_inputs, we can use the inputs value in the dummy_input
value = self._vnode_to_value(self.debugname_to_value[name])
self.internal_result[name] = value
# create the mask tensor for the input value
if isinstance(self.internal_result[name], torch.Tensor):
self.masks[name] = torch.ones_like(value)
self.constant[name] = torch.zeros_like(value)
elif nodeio.input_or_output == 'output':
graph_output.append((name, nodeio))
# count the degree for the node in the graph
in_degree = {}
out_degree = {}
visit_queue = queue.Queue()
for node in self.torch_graph.nodes_py.nodes_op:
successors = self.torch_graph.find_successors(node.unique_name)
out_degree[node.unique_name] = len(successors)
predecessors = self.torch_graph.find_predecessors(node.unique_name)
in_degree[node.unique_name] = len(predecessors)
if in_degree[node.unique_name] == 0:
visit_queue.put(node)
# Forward mask inference
while not visit_queue.empty():
curnode = visit_queue.get()
# forward mask inference for curnode
self.update_direct_sparsity(curnode)
successors = self.torch_graph.find_successors(curnode.unique_name)
for successor in successors:
in_degree[successor] -= 1
if in_degree[successor] == 0:
visit_queue.put(self.torch_graph.name_to_node[successor])
# backward mask inference
for unique_name in out_degree:
if out_degree[unique_name] == 0:
visit_queue.put(self.torch_graph.name_to_node[unique_name])
while not visit_queue.empty():
curnode = visit_queue.get()
self.update_indirect_sparsity(curnode)
predecessors = self.torch_graph.find_predecessors(
curnode.unique_name)
for predecessor in predecessors:
out_degree[predecessor] -= 1
if out_degree[predecessor] == 0:
visit_queue.put(self.torch_graph.name_to_node[predecessor])
def replace_compressed_modules(self):
"""
......@@ -148,40 +377,138 @@ class ModelSpeedup:
NOTE: ```func``` type cannot be replaced as it is not a module, thus, one limitation
is that ```func``` should not be required to be replaced.
"""
for module_name in self.inferred_masks:
g_node = self.torch_graph.name_to_node[module_name]
_logger.debug("replace %s, in %s type, with op_type %s",
module_name, g_node.type, g_node.op_type)
if g_node.type == 'module':
super_module, leaf_module = get_module_by_name(self.bound_model, g_node.name)
m_type = g_node.op_type
if not m_type in replace_module:
raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
_logger.info("replace module (name: %s, op_type: %s)", g_node.name, m_type)
compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
setattr(super_module, g_node.name.split('.')[-1], compressed_module)
elif g_node.type == 'func':
_logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
module_name, g_node.op_type)
else:
raise RuntimeError("Unsupported node type: {}".format(g_node.type))
with torch.no_grad():
for unique_name in self.auto_inferences:
self.replace_submodule(unique_name)
def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
"""
Replace the submodule according to the inferred sparsity.
Parameters
----------
unique_name: str
The unique_name of the submodule to replace.
reindex_dim: int
The dimension of the re-index operation.
reindex: Reindex
The index tensor. Normally this variable is None. If we want to reindex the
output of this submodule, we can pass the index by this parameter.
"""
class ReindexModule(nn.Module):
"""
ReindexModule is used to resolve the mask conflict when replacing the submodule.
Basically, we can use two ways to resolve the mask conflict: (1) unmask some
values (introduces more computation overhead); (2) reindex and pad the output
tensor of the target op (introduces more memory access overhead). Currently this
method is disabled; in the future, we will merge these two methods into a graph
pass that resolves the mask conflict.
"""
def __init__(self, ori_module, reindex_dim, reindex):
super(ReindexModule, self).__init__()
self.ori_module = ori_module
self.reindex_dim = reindex_dim
self.reindex = reindex
tmp_index = [slice(None, None) for i in range(reindex_dim+1)]
# the index for the tensor
tmp_index[reindex_dim] = reindex
self.t_index = tuple(tmp_index)
def forward(self, x):
tmpout = self.ori_module(x)
shape = list(tmpout.size())
shape[self.reindex_dim] = self.reindex.size(0)
out = torch.zeros(tuple(shape), device=tmpout.device,
requires_grad=tmpout.requires_grad)
out[self.t_index] = tmpout
return out
assert unique_name in self.auto_inferences
g_node = self.torch_graph.name_to_node[unique_name]
_logger.debug("replace %s, in %s type, with op_type %s",
unique_name, g_node.type, g_node.op_type)
auto_infer = self.auto_inferences[unique_name]
if g_node.type == 'module':
if g_node.unique_name in self.torch_graph.reused_module:
if reindex_dim is not None:
_logger.warning(
'Cannot replace a reused module with padding operator!!')
return None
super_module, leaf_module = get_module_by_name(
self.bound_model, g_node.name)
m_type = g_node.op_type
if not m_type in replace_module:
raise RuntimeError(
"Has not supported replacing the module: `{}`".format(m_type))
_logger.info("replace module (name: %s, op_type: %s)",
g_node.name, m_type)
compressed_module = replace_module[m_type](
leaf_module, auto_infer.get_masks())
new_submodule = compressed_module
if reindex_dim is None:
setattr(super_module, g_node.name.split(
'.')[-1], compressed_module)
elif reindex_dim is not None and reindex is not None:
# reindex the output of this submodule and replace the original module
new_submodule = ReindexModule(
compressed_module, reindex_dim, reindex)
setattr(super_module, g_node.name.split(
'.')[-1], new_submodule)
return new_submodule
elif g_node.type == 'func':
_logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
unique_name, g_node.op_type)
return None
else:
raise RuntimeError("Unsupported node type: {}".format(g_node.type))
def initialize_speedup(self):
"""
Do some initial work for speedup.
"""
# initialize the self.debugname_to_value
# build a mapping table from the debug name of the tensor
# to its value node in the graph
traced_graph = self.torch_graph.trace.graph
for node in traced_graph.nodes():
for _input in node.inputs():
debug_name = _input.debugName()
if debug_name not in self.debugname_to_value:
self.debugname_to_value[debug_name] = _input
for _output in node.outputs():
debug_name = _output.debugName()
if debug_name not in self.debugname_to_value:
self.debugname_to_value[debug_name] = _output
# put the model itself into internal_result to perform the
# value inference for 'prim::GetAttr'; the first ClassType
# of the whole graph is the model class
for graph_input in traced_graph.inputs():
if graph_input.type().kind() == 'ClassType':
self.internal_result[graph_input.debugName()
] = self.bound_model
break
def speedup_model(self):
"""
There are basically two steps: first, do mask/shape inference;
second, replace modules.
"""
training = self.bound_model.training
_logger.info("start to speed up the model")
_logger.info("fix the mask conflict of the interdependent layers")
_, conv_prune_dim = fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
set_conv_prune_dim(conv_prune_dim)
self.initialize_speedup()
# set to the evaluation mode
self.bound_model.train(False)
# TODO: it is more elegant to fix the conflict after the sparsity propagation
_logger.info("infer module masks...")
self.infer_modules_masks()
# load the original state dict before replacing the modules
self.bound_model.load_state_dict(self.ori_state_dict)
_logger.info("replace compressed modules...")
# the mask conflict should already be resolved
self.replace_compressed_modules()
self.bound_model.train(training)
_logger.info("speedup done")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
from ..utils import randomize_tensor, torch_float_dtype, torch_integer_dtype
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
STD_DELTA = 1e-6
class AutoMaskInference:
def __init__(self, module, dummy_input, in_masks=None, weight_mask=None, \
output_mask=None, name=None, in_constants=None, state_dict=None, batch_dim=0):
"""
This class will infer the mask of the target module automatically.
The update_direct_sparsity method will infer the output mask according
to the input masks; in contrast, update_indirect_sparsity will
infer the input masks according to the given output masks. The newly
found sparsity will be incrementally merged into the original in_masks
and output_mask.
Parameters
----------
module: torch.nn.Module/function
The target module to infer the mask. Need to be callable.
dummy_input: torch.Tensor/list of Tensor
The dummy_input of the target module.
in_masks: list of torch.Tensor
The input masks of the target module, if in_masks is not None, then
update_direct_sparsity and update_indirect_sparsity will incrementally
update the given in_masks, else, AutoMaskInference will create a new
in_masks for the target module.
output_mask: torch.Tensor
The output mask of the target module. Similar to in_masks, if output_mask
is not None, then update_direct_sparsity and update_indirect_sparsity will
incrementally update the given output_mask, else AutoMaskInference will create
one output_mask for the target module.
weight_mask: dict of the weight masks
The weight masks of the target module, the key is the corresponding name of
the mask. For example: {'weight': torch.ones(1000, 1000), 'bias': torch.ones(1000)}
name: str
Name of the target module.
in_constants: list of torch.Tensor
The corresponding constant values of the in_masks.
state_dict: dict of torch.Tensor
The original values of the weights.
batch_dim: int
The index of the batch dimension of the input tensors.
"""
errmsg = '%s is not callable, should pass an nn.Module/function' % str(
module)
assert callable(module), errmsg
self.module = module
# Initialize the dummy_input
if isinstance(dummy_input, list):
# if there are multiple input variables
self.dummy_input = dummy_input
else:
# if there is only one input variable
self.dummy_input = [dummy_input]
# Initialize the masks for input tensors
self.in_masks = in_masks if in_masks is not None else [
None] * len(self.dummy_input)
self.in_constants = in_constants if in_constants is not None else [
torch.zeros_like(x) for x in self.dummy_input]
for in_id, _ in enumerate(self.in_masks):
if self.in_masks[in_id] is None and \
isinstance(self.dummy_input[in_id], torch.Tensor):
# if the input mask is None, create an all-ones mask for the corresponding input tensor
self.in_masks[in_id] = torch.ones_like(self.dummy_input[in_id])
# ones_like will put the created mask on the same device as the dummy_input
# Initialize the mask for output tensors
self.output = self.module(*self.dummy_input)
# self.output.requires_grad_()
if output_mask is not None:
# assume the given output mask is right
self.output_mask = output_mask
else:
if isinstance(self.output, torch.Tensor):
self.output_mask = torch.ones_like(self.output)
elif isinstance(self.output, list) or isinstance(self.output, tuple):
self.output_mask = []
for o_tensor in self.output:
if isinstance(o_tensor, torch.Tensor):
self.output_mask.append(torch.ones_like(o_tensor))
else:
# if one of the outputs is not tensor, set the corresponding
# mask to None
self.output_mask.append(None)
else:
self.output_mask = None
# Initialize the mask for the parameters
self.weights = {}
self.weight_mask = {}
if weight_mask:
self.weight_mask.update(weight_mask)
if isinstance(self.module, nn.Module):
# a function should not have parameters
# get all the parameter tensors of the target module
# note: do not shadow the `name` argument with the loop variable
for para_name, para in module.named_parameters():
self.weights[para_name] = para
if para_name not in self.weight_mask:
self.weight_mask[para_name] = torch.ones_like(para.data)
self.name = name
self.state_dict = state_dict
# TODO support the other batch dimension in the future
self.batch_dim = batch_dim
def random_init(self, start=0.1, end=8.0):
"""
Randomly initialize the weights of the module. The values of
the tensors will not affect the mask auto inference.
"""
# currently we set the random range to 0.1-8.0 because of ReLU6:
# if we use a range far larger than 6, it may infer a wrong mask
# when the confidence is low. In the future, we will add mask inference
# rules for ReLU6 to remove this range constraint.
with torch.no_grad():
for tensor in self.dummy_input:
if isinstance(tensor, torch.Tensor) and len(tensor.size()) > 0:
# if the tensor is a scalar, then skip this tensor
randomize_tensor(tensor, start, end)
for para in self.weights:
randomize_tensor(self.weights[para].data, start, end)
def zero_grad(self):
"""
Set the gradient of the weight, input tensor to be zeros.
"""
with torch.no_grad():
# set the weight's gradient to zero
if isinstance(self.module, nn.Module):
self.module.zero_grad()
# also zero the gradient of the input tensors
for tensor in self.dummy_input:
if isinstance(tensor, torch.Tensor):
if tensor.grad is not None:
tensor.grad.data.zero_()
def requires_grad_(self, flag=True):
"""
Set the requires_grad of input tensor and parameters to flag.
"""
for t_in in self.dummy_input:
if isinstance(t_in, torch.Tensor) and t_in.dtype in torch_float_dtype:
# only float type can require the gradient
# enable the auto gradient
t_in.requires_grad_(flag)
for para_name in self.weights:
if self.weights[para_name].dtype in torch_float_dtype:
self.weights[para_name].requires_grad_(flag)
def apply_mask(self):
self.__apply_input_mask()
self.__apply_weight_mask()
def __apply_input_mask(self):
"""
Apply the mask of the input tensor.
"""
with torch.no_grad():
# apply the input mask
for tid, in_tensor in enumerate(self.dummy_input):
if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
in_tensor.data = in_tensor.data * \
self.in_masks[tid] + \
(1-self.in_masks[tid]) * self.in_constants[tid]
def __apply_weight_mask(self):
"""
Apply the weight mask of this module.
"""
with torch.no_grad():
# apply the weight mask
for para in self.weights:
if para in self.weight_mask:
self.weights[para].data *= self.weight_mask[para].data
def isconstants(self, tout):
"""
Find the constants in the tensor tout. This function returns a mask tensor that
indicates whether each value in tout is a constant, and one more tensor that
gives the values of the constants.
Parameters
----------
tout: torch.Tensor
The target output tensor in which to find the constants
Returns
-------
mask: torch.Tensor
The mask tensor (same shape as tout) that indicates whether
the corresponding value is a constant.
constant: torch.Tensor
The tensor (same shape as tout) that gives the values of
the constants in tout.
"""
assert isinstance(tout, torch.Tensor)
out_mask = torch.ones_like(tout)
constant = torch.zeros_like(tout)
# judge if tout is a scalar (a tensor that only has one value)
if len(tout.size()) == 0:
# tout is a scalar tensor, for the scalar tensor, we take
# this scalar as a constant, usually, the scalar tensor is returned
# by the size() function
constant = tout
return out_mask, constant
if tout.dtype in torch_integer_dtype:
# Pytorch cannot use torch.mean and torch.std to process
# integers :( , so if the dtype of the input tensor is integer, we need
# to check for constants by ourselves
# Note: the first dimension should be the batch dimension
same = tout[:] == tout[0]
reduced = torch.sum(same, dim=0)
is_constant = reduced == tout.size(0)
out_mask[:, is_constant] = 0
constant[:, is_constant] = tout[0][is_constant]
else:
# calculate the std of the output among batch dimension
std = torch.std(tout, dim=0)
# calculate the mean value of the output among the batch dimension
mean = torch.mean(tout, dim=0)
mask_pos = std < STD_DELTA
out_mask[:, mask_pos] = 0
constant[:, mask_pos] = mean[mask_pos]
return out_mask, constant
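# A worked example of the std-based branch above (hypothetical values):
#
#     tout = torch.tensor([[1.0, 2.0],
#                          [1.0, 5.0]])
#     # std over the batch dimension is [0.0, ~2.12], so only position 0
#     # is taken as a constant:
#     # out_mask == [[0., 1.], [0., 1.]], constant == [[1., 0.], [1., 0.]]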
def update_indirect_sparsity(self):
"""
This function will update the indirect sparsity. To explain what
indirect sparsity is: suppose there are two tensors TA and TB, and
we perform the calculation TC = TA x TB, in which TC is also a tensor.
Once some values in TA are masked to zeros, the corresponding
positions in TB are also potential sparsity, because they have no
effect on the final output (the gradients of these positions in TB equal
0 all the time). This function is to find the potential sparsity caused
by other sparsity (we call it indirect sparsity here). Basically, we can
find this potential sparsity through gradients.
"""
# Each node only updates its output mask during the backward pass.
# This is because some ops may have a broadcast behaviour; for example,
# OP A's output tensor may be taken by two OPs (B, C) as input. So we
# cannot directly update the input mask at OP B or C. We can only
# update the mask of A's output tensor once B and C are both
# already updated (the gradients are already calculated and accumulated
# into A's output tensor).
# Besides, updating the mask of A's output tensor equals updating
# the input masks of OPs B and C.
if isinstance(self.output, torch.Tensor) and self.output.grad is not None:
# if the output has a gradient, then this node has successor
# nodes and the successor nodes have already updated their indirect
# sparsity
# we can mask the values whose gradients are always zeros
gradient_sum = torch.sum(torch.abs(self.output.grad.data), dim=0)
_grad_zero = gradient_sum == 0
for batchid in range(self.output.size(0)):
# set the same mask value for the whole batch
self.output_mask[batchid][_grad_zero] = 0
elif isinstance(self.output, tuple) or isinstance(self.output, list):
assert isinstance(self.output_mask, (tuple, list))
for oid, tout in enumerate(self.output):
errmsg = 'The output only supports tensor/list of tensors'
assert isinstance(tout, torch.Tensor), errmsg
if tout.grad is None:
continue
# note: sum the gradient of this output tensor, not self.output,
# which is a tuple/list here
gradient_sum = torch.sum(torch.abs(tout.grad.data), dim=0)
_grad_zero = gradient_sum == 0
for batchid in range(tout.size(0)):
# set the same mask value for the whole batch
self.output_mask[oid][batchid][_grad_zero] = 0
self.requires_grad_(True)
# Forward inference with auto gradient enabled
# Note: tensors that need gradient cannot be used in the in-place operator
self.random_init()
self.apply_mask()
# Some operators may have in-place operations, so we need to clone the input
# before passing it to self.module
tmp_dummy_input = [x.clone() if isinstance(
x, torch.Tensor) else x for x in self.dummy_input]
output = self.module(*tmp_dummy_input)
# Note: the output may be a tensor or a list/tuple of tensors
if isinstance(output, torch.Tensor):
if output.grad_fn is None:
# the output does not have a gradient function
return
output.backward(self.output_mask)
elif isinstance(output, list) or isinstance(output, tuple):
for tid, t_out in enumerate(output):
if t_out.grad_fn is not None:
t_out.backward(self.output_mask[tid])
# update the sparsity of the parameters
for para_name in self.weights:
if self.weights[para_name].grad is None:
continue
grad_zero = self.weights[para_name].grad.data == 0
self.weight_mask[para_name][grad_zero] = 0
def update_direct_sparsity(self):
# we don't need the gradient in the forward inference
out_mask = None
constant = None
with torch.no_grad():
# Note: we need to randomly initialize the input one more time here,
# because some operations are in-place, such as relu_;
# an in-place operation may modify or write 0s into the dummy_input
self.random_init()
# apply the mask for the input tensor and the weight tensor
self.apply_mask()
# Note: due to in-place operators such as relu_,
# the output may be the same tensor as the dummy_input,
# so we use clone and detach to create a new tensor with
# the same values.
out = self.module(*self.dummy_input)
if isinstance(out, torch.Tensor):
out_mask, constant = self.isconstants(out.clone().detach())
elif isinstance(out, tuple) or isinstance(out, list):
out_mask = []
constant = []
for tout in out:
_mask, _constant = self.isconstants(tout.clone().detach())
out_mask.append(_mask)
constant.append(_constant)
else:
_logger.warning(
'Only support the OP whose output is a tensor/tuple of tensors/list of tensors')
# We also need to randomize the parameters of the module, because if the weight of the model has
# an unmasked 0, then our output sparsity inference may be wrong.
# However, after randomizing the weights/parameters, the constants in the output tensors may
# be different from the constants calculated from the original state_dict. Therefore,
# to get the right constants and eliminate the bias between the model before and after sparsity
# inference, we need to reload the state_dict and recalculate the constants.
# Currently we also get the constant values at the same time as inferring the mask; in
# the future, we will separate the constant inference into a single graph pass.
if len(self.weights) > 0 and self.state_dict is not None:
self.module.load_state_dict(self.state_dict)
# apply weight mask
self.__apply_weight_mask()
out = self.module(*self.dummy_input).clone().detach()
if isinstance(out, torch.Tensor):
constant = torch.zeros_like(out)
constant_pos = out_mask == 0
constant[constant_pos] = out[constant_pos]
elif isinstance(out, (list, tuple)):
constant = []
for i, tout in enumerate(out):
_tmp = torch.zeros_like(tout)
sparsity_pos = out_mask[i] == 0
_tmp[sparsity_pos] = tout[sparsity_pos]
constant.append(_tmp)
if isinstance(out_mask, torch.Tensor):
assert isinstance(self.output_mask, torch.Tensor)
self.output_mask *= out_mask
elif isinstance(out_mask, list):
for i, _ in enumerate(out_mask):
self.output_mask[i] *= out_mask[i]
else:
_logger.warning('There is no output sparsity')
# also save the out_constant
self.out_constant = constant
def get_masks(self):
return (self.in_masks, self.output_mask, self.weight_mask)
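# A minimal driver sketch for AutoMaskInference (a hypothetical example, not
# part of the speedup pipeline): infer the output sparsity of a Linear layer
# whose first output feature is fully masked.
#
#     import copy
#     module = nn.Linear(4, 3)
#     w_mask = {'weight': torch.ones(3, 4), 'bias': torch.ones(3)}
#     w_mask['weight'][0] = 0
#     w_mask['bias'][0] = 0
#     infer = AutoMaskInference(module, torch.rand(8, 4), weight_mask=w_mask,
#                               state_dict=copy.deepcopy(module.state_dict()))
#     infer.update_direct_sparsity()
#     _, output_mask, _ = infer.get_masks()
#     # output_mask[:, 0] is now all zeros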
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
For each operation or module, there are two functions.
One is given output shape, infer its input shape and initialization parameters (e.g., weight's shape)
The other is given input shape, infer its output shape and initialization parameters (e.g., weight's shape)
"""
import logging
import torch
_logger = logging.getLogger(__name__)
conv_prune_dim = -1
def set_conv_prune_dim(dim):
"""
Parameters
----------
dim: int
0: filter pruning
1: channel pruning
"""
global conv_prune_dim
conv_prune_dim = dim
class CoarseMask:
"""
Coarse grained mask for a given tensor, here tensor could be weights,
input tensor, or output tensor
"""
def __init__(self, num_dim):
"""
Parameters
----------
num_dim : int
The number of dimensions of the tensor that will be masked
"""
self.mask_index = [None for _ in range(num_dim)]
def add_index_mask(self, dim, index):
"""
Add mask for the specified dimension
Parameters
----------
dim : int
The dimension to add mask
index : tensor
The mask for this dimension; it is a 1-dimensional tensor which specifies
the indexes of the elements that are not pruned
"""
self.mask_index[dim] = index
@staticmethod
def merge_index(index_a, index_b):
"""
Parameters
----------
index_a : tensor
One index (1-dimension) tensor
index_b : tensor
The other index (1-dimension) tensor
Returns
-------
tensor
The merged index (1-dimension) tensor
Note that: the output tensor will be moved
to the same device as index_a.
"""
device = index_a.device
# We need to convert the tensors to python lists first: directly
# traversing a tensor with a for loop yields 0-dim tensor objects,
# which are distinct objects even when their values are equal, so the
# set would contain multiple tensor objects with the same value.
# For example:
# for num in torch.ones(2):
# s.add(num)
# s would be {tensor(1.), tensor(1.)}
s = set(index_a.tolist()) | set(index_b.tolist())
# move the output tensor to the same device as index_a
return torch.tensor(sorted(s)).to(device) # pylint: disable=not-callable
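# Example (hypothetical values): merging the surviving-index tensors of two
# masks keeps the sorted, deduplicated union of the indices:
#
#     CoarseMask.merge_index(torch.tensor([0, 2]), torch.tensor([1, 2]))
#     # -> tensor([0, 1, 2])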
def merge(self, cmask):
"""
Merge another CoarseMask
Parameters
----------
cmask : CoarseMask
Another CoarseMask to merge
Returns
-------
list
The member variable ```mask_index```
"""
assert isinstance(cmask, CoarseMask)
assert len(self.mask_index) == len(cmask.mask_index), \
"Only masks with the same number of dimensions can be merged"
for i, index in enumerate(self.mask_index):
if index is None:
self.mask_index[i] = cmask.mask_index[i]
elif cmask.mask_index[i] is not None:
self.mask_index[i] = CoarseMask.merge_index(self.mask_index[i],
cmask.mask_index[i])
return self.mask_index
def __repr__(self):
return 'mask_index: {}'.format(self.mask_index)
def eq_on_dim(self, other, dim):
assert isinstance(other, CoarseMask)
if self.mask_index[dim] is None and other.mask_index[dim] is None:
return True
elif isinstance(self.mask_index[dim], torch.Tensor) \
and isinstance(other.mask_index[dim], torch.Tensor):
return torch.equal(self.mask_index[dim], other.mask_index[dim])
else:
return False
def __eq__(self, other):
assert isinstance(other, CoarseMask)
if len(self.mask_index) != len(other.mask_index):
return False
for i in range(len(self.mask_index)):
if not self.eq_on_dim(other, i):
return False
return True
def __lt__(self, other):
"""
Judge if the mask is a subset of another CoarseMask.
"""
assert isinstance(other, CoarseMask)
for dim, _ in enumerate(self.mask_index):
# if self has more dimensions
if dim >= len(other.mask_index):
return False
if self.mask_index[dim] is None:
# if there is no mask on this dimension, then we have fewer
# masks than the other CoarseMask.
continue
elif other.mask_index[dim] is None:
return False
else:
s1 = set(self.mask_index[dim].tolist())
s2 = set(other.mask_index[dim].tolist())
if not s1 < s2:
return False
return True
def __le__(self, other):
"""
Return if self's mask is less or equal to other's mask.
"""
assert isinstance(other, CoarseMask)
if self.__lt__(other) or self.__eq__(other):
return True
return False
def __ne__(self, other):
return not self.__eq__(other)
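# Example of the subset semantics above (hypothetical values): a mask that
# keeps fewer indices on a dimension is "less than" one that keeps more:
#
#     a = CoarseMask(num_dim=2)
#     b = CoarseMask(num_dim=2)
#     a.add_index_mask(dim=1, index=torch.tensor([0]))
#     b.add_index_mask(dim=1, index=torch.tensor([0, 2]))
#     assert a < b and a <= b and a != b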
class ModuleMasks:
"""
The masks of a module, including the masks for weights, inputs, output
"""
def __init__(self, module_name, module=None):
"""
Parameters
----------
module_name : str
The name of the module or function
module : torch.nn.Module
The module object, if available
"""
self.module_name = module_name
self.module = module
self.param_masks = dict()
self.input_mask = None
self.output_mask = None
def set_param_masks(self, name, mask):
"""
Parameters
----------
name : str
The name of the weight
mask : CoarseMask
The mask for this weight
"""
self.param_masks[name] = mask
def set_input_mask(self, mask):
"""
Parameters
----------
mask : CoarseMask
The mask for input
"""
self.input_mask = mask
def set_output_mask(self, mask):
"""
Parameters
----------
mask : CoarseMask
The mask for output
"""
self.output_mask = mask
def __repr__(self):
return 'module_name: {}, input_mask: {}, output_mask: {}, param_masks: {}'.format(
self.module_name, self.input_mask, self.output_mask, self.param_masks
)
"""
Infer input and output shape of a module/function from its weight mask
"""
infer_from_mask = {
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask),
'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_mask(module_masks, mask),
'Linear': lambda module_masks, mask, shape: linear_mask(module_masks, mask, shape)
}
"""
Infer output and weight shape of a module/function from its input shape
"""
infer_from_inshape = {
'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::tanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::tanh_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::hardtanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::hardtanh_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_inshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::adaptive_avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'AvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::size': lambda module_masks, mask: size_inshape(module_masks, mask),
'aten::view': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
'aten::reshape': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
# support only start_dim=1
'aten::flatten': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
'Linear': lambda module_masks, mask: linear_inshape(module_masks, mask),
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask),
'aten::add_': lambda module_masks, mask: add_inshape(module_masks, mask),
'aten::add': lambda module_mask, mask: add_inshape(module_mask, mask),
# mul has similar behaviour to add: they both require
# the input tensors to have the same shape
'aten::mul': lambda module_mask, mask: add_inshape(module_mask, mask),
'aten::mul_': lambda module_mask, mask: add_inshape(module_mask, mask),
'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited),
'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'aten::detach': lambda module_masks, mask: dropout_inshape(module_masks, mask)
}
"""
Infer input and weight shape of a module/function from its output shape
"""
infer_from_outshape = {
'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask),
'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_outshape(module_masks, mask),
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_outshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'aten::adaptive_avg_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'AvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
'ReLU': lambda module_masks, mask: relu_outshape(module_masks, mask),
'ReLU6': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::tanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::tanh_': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::hardtanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::hardtanh_': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::relu_': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::add_': lambda module_masks, mask: add_outshape(module_masks, mask),
'aten::add': lambda module_mask, mask: add_outshape(module_mask, mask),
'aten::flatten': lambda module_mask, mask, shape: view_outshape(module_mask, mask, shape),
'aten::view': lambda module_masks, mask, shape: view_outshape(module_masks, mask, shape),
'aten::reshape': lambda module_masks, mask, shape: view_outshape(module_masks, mask, shape),
'aten::mean': lambda module_masks, mask, shape: mean_outshape(module_masks, mask, shape),
'Dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask),
'Dropout2d': lambda module_masks, mask: dropout_outshape(module_masks, mask),
'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask),
'aten::detach': lambda module_masks, mask: dropout_outshape(module_masks, mask)
}
def dropout_inshape(module_masks, mask):
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return module_masks.output_mask
# if already visited
assert module_masks.input_mask <= mask
# The masks are passed by reference (not by value), so mask and
# module_masks.input_mask are actually two references to the same
# object. Therefore we should keep passing the mask to the following
# nodes even when module_masks.input_mask == mask.
# If the mask were passed by copy.deepcopy(), we could stop when
# module_masks.input_mask == mask:
# if module_masks.input_mask == mask:
# return None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return module_masks.output_mask
def dropout_outshape(module_masks, mask):
if module_masks.output_mask is None:
module_masks.set_output_mask(mask)
module_masks.set_input_mask(mask)
return module_masks.input_mask
# if already visited
assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
return module_masks.output_mask
def cat_inshape(module_masks, mask, cat_info, last_visited):
"""
Infer the output mask of the cat operation from the
input mask.
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the cat operation
mask : CoarseMask
The mask of its input tensor
cat_info: dict
Dict object that records the necessary information
of cat operation, such as the order of the input
tensors.
last_visited: str
The unique_name of the last visited node group.
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
out_shape = cat_info['out_shape']
cat_dim = cat_info['cat_dim']
in_order = cat_info['in_order']
in_shape = cat_info['in_shape']
if module_masks.output_mask is None:
# First visit to this cat node:
# initialize the mask based on
# the number of output channels.
output_mask = CoarseMask(num_dim=len(out_shape))
for dim, _ in enumerate(out_shape):
if dim == cat_dim:
if mask.mask_index[dim] is None:
continue
device = mask.mask_index[dim].device
# calculate the offset of the mask
pos = in_order.index(last_visited)
offsets = [in_shape[i][cat_dim]
for i, _ in enumerate(in_shape)]
offset = 0
for i in range(pos):
offset += offsets[i]
_tmp_mask = (mask.mask_index[dim] + offset).to(device)
output_mask.mask_index[dim] = _tmp_mask
else:
# directly copy the mask
if mask.mask_index[dim] is not None:
output_mask.mask_index[dim] = mask.mask_index[dim].data.clone(
)
module_masks.set_output_mask(output_mask)
return module_masks.output_mask
# If this cat node is already visited, we need to
# validate that the mask is legal. For the cat operation,
# the masks on the 'cat_dim' dimension should be stitched
# together, and on the other dimensions the masks should be
# the same, otherwise the mask is illegal.
for dim, _ in enumerate(out_shape):
if dim == cat_dim:
if mask.mask_index[dim] is None:
continue
pos = in_order.index(last_visited)
offsets = [in_shape[i][cat_dim] for i, _ in enumerate(in_shape)]
offset = 0
for i in range(pos):
offset += offsets[i]
device = mask.mask_index[dim].device
new_mask = mask.mask_index[dim] + offset
module_masks.output_mask.mask_index[dim] = CoarseMask.merge_index(
module_masks.output_mask.mask_index[dim], new_mask).to(device)
else:
assert module_masks.output_mask.eq_on_dim(mask, dim)
return module_masks.output_mask
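# A worked example of the offset logic above (hypothetical values): two
# 4-channel tensors concatenated on cat_dim=1, with the mask coming from the
# second input (pos == 1), so its channel indices are shifted by the first
# input's channel count:
#
#     in_shape = [[1, 4, 8, 8], [1, 4, 8, 8]]
#     mask.mask_index[1] = torch.tensor([0, 2])
#     # offset = in_shape[0][1] = 4, so the output mask
#     # on dim 1 becomes tensor([4, 6])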
def add_inshape(module_masks, mask):
"""
Infer the output mask of the add operation from the
input mask.
"""
assert isinstance(mask, CoarseMask)
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
# If already visited, validate that there is no conflict:
# if the mask is different from the previous input_mask,
# then there is a mask conflict.
if mask != module_masks.input_mask:
raise Exception('Mask conflict happens!')
return None
def add_outshape(module_masks, mask):
"""
Infer the input mask of the add operation from the
output mask.
"""
assert isinstance(mask, CoarseMask)
if module_masks.output_mask is None:
module_masks.set_output_mask(mask)
module_masks.set_input_mask(mask)
return mask
else:
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
return mask
def batchnorm2d_inshape(module_masks, mask):
"""
We assume only the second dimension has coarse grained mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
weight_cmask = CoarseMask(num_dim=1)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', weight_cmask)
return mask
def batchnorm2d_outshape(module_masks, mask):
"""
We assume only the second dimension has coarse grained mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
The mask of its input tensor
"""
assert isinstance(mask, CoarseMask)
assert len(mask.mask_index) in [2, 4]
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
weight_cmask = CoarseMask(num_dim=1)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', weight_cmask)
return mask
def linear_inshape(module_masks, mask):
"""
Coarse grained input mask does not change the shape of weights and output tensor
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the linear
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor, ```None``` means shape of output tensor is not changed
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[0] is None
if module_masks.input_mask is not None:
assert module_masks.input_mask <= mask
module_masks.set_input_mask(mask)
return None
def view_inshape(module_masks, mask, shape):
"""
This is a limited support.
TODO: consider replacing tensor.view with nn.Flatten, because tensor.view is
not a module and thus cannot be replaced by our framework.
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the ```view``` op
mask : CoarseMask
The mask of its input tensor
shape : dict
Original shape of its input and output tensors
Returns
-------
CoarseMask
The mask of its output tensor
"""
# NOTE: the case constrained by the following four asserts
assert shape['in_shape'][0] == shape['out_shape'][0]
assert len(shape['in_shape']) == 4
assert len(shape['out_shape']) == 2
assert shape['out_shape'][1] == shape['in_shape'][1] * \
shape['in_shape'][2]*shape['in_shape'][3]
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
# due to the cat operation, the same node may be
# accessed more than once
if module_masks.input_mask is not None:
assert module_masks.input_mask <= mask
module_masks.set_input_mask(mask)
output_cmask = CoarseMask(num_dim=2)
index = []
step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)])
output_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable
module_masks.set_output_mask(output_cmask)
return output_cmask
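# A worked example of the index expansion above (hypothetical values): for
# in_shape [N, 4, 2, 2] flattened to out_shape [N, 16], step_size is 2*2 = 4,
# so keeping input channels [1, 3] keeps the flattened positions below:
#
#     step_size = 4
#     mask.mask_index[1] = torch.tensor([1, 3])
#     # output mask on dim 1: [4, 5, 6, 7, 12, 13, 14, 15]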
def view_outshape(module_masks, mask, shape):
"""
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the ```view``` op
mask : CoarseMask
The mask of its output tensor
shape : dict
Original shape of its input and output tensors
Returns
-------
CoarseMask
The mask of its input tensor
"""
# NOTE: the case constrained by the following four asserts
assert shape['in_shape'][0] == shape['out_shape'][0]
assert len(shape['in_shape']) == 4
assert len(shape['out_shape']) == 2
assert shape['out_shape'][1] == shape['in_shape'][1] * \
shape['in_shape'][2]*shape['in_shape'][3]
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
module_masks.set_output_mask(mask)
input_cmask = CoarseMask(num_dim=4)
index = set()
step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]:
index.add(loc // step_size)
index = sorted(list(index))
input_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable
module_masks.set_input_mask(input_cmask)
return input_cmask
def size_inshape(module_masks, mask):
"""
No need to do anything for this ```size``` op
"""
return None
def mean_inshape(module_masks, mask, shape):
"""
Similar to view operation, currently mask inference only supports
the mean operation on the 3rd and 4th dimensions.
"""
assert shape['in_shape'][0] == shape['out_shape'][0]
assert shape['out_shape'][1] == shape['in_shape'][1]
assert len(shape['in_shape']) == 4
assert len(shape['out_shape']) == 2
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
module_masks.set_input_mask(mask)
output_cmask = CoarseMask(num_dim=2)
output_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
module_masks.set_output_mask(output_cmask)
return output_cmask
def mean_outshape(module_masks, mask, shape):
"""
Similar to view operation, currently mask inference only supports
the mean operation on the 3rd and 4th dimensions.
"""
assert shape['in_shape'][0] == shape['out_shape'][0]
assert shape['out_shape'][1] == shape['in_shape'][1]
assert len(shape['in_shape']) == 4
assert len(shape['out_shape']) == 2
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
module_masks.set_output_mask(mask)
input_cmask = CoarseMask(num_dim=4)
input_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
module_masks.set_input_mask(input_cmask)
return input_cmask
def maxpool2d_inshape(module_masks, mask):
"""
Assume only the second dimension is masked
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the maxpool2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
if module_masks.input_mask is not None:
assert module_masks.input_mask <= mask
# assert module_masks.input_mask is None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def maxpool2d_outshape(module_masks, mask):
"""
Assume only the second dimension is masked
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the maxpool2d
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
The mask of its input tensor
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def relu_inshape(module_masks, mask):
"""
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the relu
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
if module_masks.input_mask is not None:
# mask conflict should be solved before speedup
assert module_masks.input_mask <= mask
# assert module_masks.input_mask is None, "A relu op can only be processed once"
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def relu_outshape(module_masks, mask):
"""
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the relu
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
The mask of its input tensor
"""
assert isinstance(mask, CoarseMask)
if module_masks.output_mask is not None:
# mask conflict should be solved before speedup
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def batchnorm2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
mask : dict
The mask of its weights, from the user provided mask file
Returns
-------
CoarseMask, CoarseMask
The mask of its input tensor, the mask of its output tensor
"""
assert 'weight' in mask and 'bias' in mask
sum_mask = mask['weight'] + mask['bias']
nonzero_index = torch.nonzero(sum_mask, as_tuple=True)[0]
# infer shape of parameters
param_cmask = CoarseMask(num_dim=1)
param_cmask.add_index_mask(dim=0, index=nonzero_index)
module_masks.set_param_masks('weight', param_cmask)
module_masks.set_param_masks('bias', param_cmask)
# infer shape of input tensor
input_cmask = CoarseMask(num_dim=4)
input_cmask.add_index_mask(dim=1,
index=torch.nonzero(mask['weight'], as_tuple=True)[0])
module_masks.set_input_mask(input_cmask)
# infer shape of output tensor
output_cmask = CoarseMask(num_dim=4)
output_cmask.add_index_mask(dim=1, index=nonzero_index)
module_masks.set_output_mask(output_cmask)
return input_cmask, output_cmask
def linear_mask(module_masks, mask, shape):
"""
Infer input and output shape from weight mask, with a limitation:
only inferring the input mask is supported
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the Linear
mask : dict
The mask of its weights, from the user provided mask file
shape: dict
Shape of its input and output tensors
Returns
-------
CoarseMask, CoarseMask
The mask of its input tensor, the mask of its output tensor
"""
assert 'weight' in mask
num_input_dim = len(shape['in_shape'])
# Input data of the Linear module can have multiple dimensions;
# here we only support inferring the coarse mask on the last dimension
nonzero_index = torch.nonzero(mask['weight'].sum(0), as_tuple=True)[0]
# infer shape of input tensor
input_cmask = CoarseMask(num_dim=num_input_dim)
input_cmask.add_index_mask(dim=num_input_dim-1, index=nonzero_index)
module_masks.set_input_mask(input_cmask)
return input_cmask, None
def conv2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the conv2d
mask : dict
The mask of its weights, from the user provided mask file
Returns
-------
CoarseMask, CoarseMask
The mask of its input tensor, the mask of its output tensor
"""
def convert_to_coarse_mask(mask, dim=0):
"""
Parameters
----------
mask : dict
Weight mask from user provided mask file
dim: int
0: filter pruning
1: channel pruning
Returns
-------
LongTensor, CoarseMask, CoarseMask
Index of the masked dimension, weight mask, bias mask
"""
assert 'weight' in mask
assert isinstance(mask['weight'], torch.Tensor)
assert dim in [0, 1]
weight_mask = mask['weight']
sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
index = torch.nonzero(weight_mask.abs().sum(
sum_idx) != 0, as_tuple=True)[0]
index = index.long().to(weight_mask.device)
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=dim, index=index)
bias_cmask = None
if dim == 0 and 'bias' in mask and mask['bias'] is not None:
bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
assert torch.all(torch.eq(index, bias_index)), \
"bias mask should be consistent with weight mask"
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=bias_index)
return index, weight_cmask, bias_cmask
index, weight_cmask, bias_cmask = convert_to_coarse_mask(
mask, dim=conv_prune_dim)
if index is None:
# TODO: fine grained mask speedup
return None, None
# deal with coarse grain mask
# mask conflict should be solved by fix_mask_conflict before speedup
if 'weight' in module_masks.param_masks:
assert module_masks.param_masks['weight'] == weight_cmask
else:
module_masks.set_param_masks('weight', weight_cmask)
if conv_prune_dim == 0:
module_masks.set_param_masks('bias', bias_cmask)
io_cmask = CoarseMask(num_dim=4)
io_cmask.add_index_mask(dim=1, index=index)
if conv_prune_dim == 0:
if module_masks.output_mask is None:
module_masks.set_output_mask(io_cmask)
else:
assert module_masks.output_mask == io_cmask
return None, module_masks.output_mask
else:
if module_masks.input_mask is None:
module_masks.set_input_mask(io_cmask)
else:
assert module_masks.input_mask == io_cmask
return module_masks.input_mask, None
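# A worked example of convert_to_coarse_mask for filter pruning (dim=0,
# hypothetical values): a conv weight mask of shape [out=3, in=2, 1, 1]
# whose second filter is fully zeroed yields the surviving filter indices:
#
#     w = torch.ones(3, 2, 1, 1)
#     w[1] = 0
#     # w.abs().sum((1, 2, 3)) == tensor([2., 0., 2.])
#     # index == tensor([0, 2]); the weight CoarseMask keeps [0, 2] on dim 0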
def conv2d_inshape(module_masks, mask):
"""
Shape change of input tensor does not affect the shape of its output tensor
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the conv2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
else:
# the same conv layer may be accessed more
# than once, such as a concat operation.
# mask conflict should be solved by fix_mask_conflict before speedup
assert module_masks.input_mask == mask
# shape changes pass through depth-wise conv layers
m = module_masks.module
if m.in_channels == m.out_channels == m.groups:
module_masks.output_mask = mask
module_masks.input_mask = mask
return mask
return None
def conv2d_outshape(module_masks, mask):
"""
Assume only the second dimension is masked
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the conv2d
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
The mask of its input tensor
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
if module_masks.output_mask is None:
module_masks.output_mask = mask
else:
# mask conflict should be solved by fix_mask_conflict before speedup
# mask and module_masks.output_mask may have different number of dimensions
# since they could be passed by linear or conv2d
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', bias_cmask)
# shape changes pass through depth-wise conv layers
m = module_masks.module
if m.in_channels == m.out_channels == m.groups:
module_masks.output_mask = mask
module_masks.input_mask = mask
return mask
return None
def convtranspose2d_mask(module_masks, mask):
# TODO: support ConvTranspose2d pruning for the L1FilterPruner
raise Exception(
"The current filter pruner cannot prune ConvTranspose2d; pruning ConvTranspose2d will be supported later")
def convtranspose2d_inshape(module_masks, mask):
"""
Shape change of input tensor does not affect the shape of its output tensor
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the ConvTranspose2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
else:
# the same conv layer may be accessed more
# than once, such as a concat operation.
# mask conflict should be solved by fix_mask_conflict before speedup
assert module_masks.input_mask == mask
# shape changes pass through depth-wise conv layers
m = module_masks.module
if m.in_channels == m.out_channels == m.groups:
module_masks.output_mask = mask
module_masks.input_mask = mask
return mask
return None
def convtranspose2d_outshape(module_masks, mask):
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
if module_masks.output_mask is None:
module_masks.output_mask = mask
else:
# mask conflict should be solved by fix_mask_conflict before speedup
# mask and module_masks.output_mask may have different number of dimensions
# since they could be passed by linear or conv2d
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
weight_cmask = CoarseMask(num_dim=4)
# Note the memory layout of Convtranspose2d is C_in, C_out, k1, k2
weight_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', bias_cmask)
# shape changes pass through depth-wise conv layers
m = module_masks.module
if m.in_channels == m.out_channels == m.groups:
module_masks.output_mask = mask
module_masks.input_mask = mask
return mask
return None
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import re
import logging
from functools import partial
import torch
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def translate_list(list_node, speedup=None):
"""
Get the list of values from the list construct node.
Parameters
----------
list_node: Torch.C.Value
The cpp node of the target list.
speedup: ModelSpeedup
The model speedup object.
Returns
-------
values: list
The list of values in the target cpp list node.
"""
# the node that create the list
create_node = list_node.node()
assert create_node.kind() == 'prim::ListConstruct'
inputs = list(create_node.inputs())
values = []
for _i in inputs:
debugName = _i.debugName()
if speedup is not None and debugName in speedup.internal_result:
# this value is the result of other nodes, such as
# aten::size
values.append(speedup.internal_result[debugName].item())
else:
# if the corresponding value is a constant
values.append(_i.toIValue())
return values
def parse_constant(cvalue, speedup):
"""
Parse the constant values from this Node
Parameters
----------
cvalue: Torch.C.Value
The cpp node of the target constant value.
speedup: ModelSpeedup
The Model speedup module.
Returns
-------
value: int/float/tensor
The constant values parsed from the node.
"""
logger.debug('Try to parse the constant value: %s', cvalue.debugName())
if cvalue.toIValue() is not None:
return cvalue.toIValue()
if cvalue.debugName() in speedup.internal_result:
return speedup.internal_result[cvalue.debugName()]
# Get the operator node of this value
op_node = cvalue.node()
inputs = op_node.inputs()
input_values = [parse_constant(_i, speedup) for _i in inputs]
func = trans_from_jit_to_python[op_node.kind()](op_node, speedup)
return func(*input_values)
def dropout_python(node, speedup):
return torch.dropout
def flatten_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
start_dim = inputs[1].toIValue()
end_dim = inputs[2].toIValue()
new_flatten = partial(torch.flatten, start_dim=start_dim, end_dim=end_dim)
return new_flatten
def relu_inplace_python(node, speedup):
return torch.relu_
def relu_python(node, speedup):
return torch.relu
def sigmoid_python(node, speedup):
return torch.sigmoid
def mean_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim_list = translate_list(inputs[1], speedup)
keep_dim = inputs[2].toIValue()
new_mean = partial(torch.mean, dim=tuple(dim_list), keepdim=keep_dim)
return new_mean
def add_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
constant = None
for i in range(2):
input_i = inputs[i]
debug_name = input_i.debugName()
if debug_name not in speedup.internal_result:
# this input is a constant value
# TODO: what if this input is a constant tensor
if input_i.toIValue() is not None:
constant = parse_constant(input_i, speedup)
break
if constant is None:
return torch.add
else:
new_add = partial(torch.add, constant)
return new_add
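# Example of the partial binding above (hypothetical values): for a jit node
# like aten::add(x, 2), the constant 2 is bound ahead of time:
#
#     new_add = partial(torch.add, 2)
#     new_add(torch.ones(3))  # -> tensor([3., 3., 3.])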
def floor_div_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
divisor = inputs[1]
constant = None
if divisor.debugName() not in speedup.internal_result:
# divisor is a constant value/tensor
constant = parse_constant(divisor, speedup)
if constant is None:
return torch.floor_divide
else:
new_op = partial(torch.floor_divide, other=constant)
return new_op
def mul_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
constant = None
for i in range(2):
input_i = inputs[i]
debug_name = input_i.debugName()
if debug_name not in speedup.internal_result:
constant = parse_constant(input_i, speedup)
# the two inputs cannot both be constants at the same time
break
if constant is None:
return torch.mul
else:
new_mul = partial(torch.mul, constant)
return new_mul
def transpose_python(node, speedup):
return torch.t
def transpose2_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim_1 = inputs[1].toIValue()
dim_2 = inputs[2].toIValue()
new_transpose = partial(torch.transpose, dim0=dim_1, dim1=dim_2)
return new_transpose
def matmul_python(node, speedup):
return torch.matmul
def div_python(node, speedup):
# The second input parameter of torch.div can be a
# tensor or a constant; if it is a constant, we need
# to bind it ahead of time with partial
c_node = node.key_node
inputs = list(c_node.inputs())
if inputs[1].debugName() in speedup.internal_result:
# the second input parameter is the output of another
# node
return torch.div
else:
other = inputs[1].toIValue()
new_div = partial(torch.div, other=other)
return new_div
def softmax_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim = inputs[1].toIValue()
new_softmax = partial(torch.softmax, dim=dim)
return new_softmax
def contiguous_python(node, speedup):
class ContiguousModule(torch.nn.Module):
def forward(self, x):
return x.contiguous()
return ContiguousModule()
def gelu_python(node, speedup):
return torch.nn.GELU()
def avgpool2d_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
kernel_size = translate_list(inputs[1], speedup)
stride = translate_list(inputs[2], speedup)
padding = translate_list(inputs[3], speedup)
new_avgpool = partial(torch.nn.functional.avg_pool2d,
kernel_size=kernel_size, stride=stride, padding=padding)
return new_avgpool
def adaptive_avgpool_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
output_size = translate_list(inputs[1], speedup)
new_avgpool = torch.nn.AdaptiveAvgPool2d(output_size)
return new_avgpool
def tupleunpack_python(node, speedup):
# Note: tuple unpack should only exist at the
# end of the model, and there is no need to replace it or propagate the mask
return None
def num2tensor_python(node, speedup):
return torch.nn.Identity()
def exp_python(node, speedup):
return torch.exp
def squeeze_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
dim = None
if len(inputs) > 1:
dim = parse_constant(inputs[1], speedup)
new_squeeze = partial(torch.squeeze, dim=dim)
return new_squeeze
##########################################################
# Split Line
# The following modules/functions cannot be translated into a
# single function, so we use torch.nn.Module to wrap the
# core function and return the torch.nn.Module instead
##########################################################
def slice_python(node, speedup):
class SliceModule(torch.nn.Module):
def __init__(self, sliceobj):
super(SliceModule, self).__init__()
self.sliceobj = sliceobj
def forward(self, x, *args):
# args carries the slice dimension and indexes; however,
# we already got them from the cpp nodes. Although we
# don't need the slice indexes any more, we cannot remove this
# parameter, because there may be multiple inputs passed from
# previous nodes such as aten::size
logger.info('Model has Slice operation, and the operand size=%s, Slice object:%s', str(
x.size()), str(self.sliceobj))
return x[self.sliceobj]
c_node = node.key_node
inputs = list(c_node.inputs())
slice_dim = parse_constant(inputs[1], speedup)
slice_start = parse_constant(inputs[2], speedup)
slice_end = parse_constant(inputs[3], speedup)
slice_step = parse_constant(inputs[4], speedup)
slice_obj = slice(slice_start, slice_end, slice_step)
slice_list = []
for _ in range(slice_dim):
slice_list.append(slice(None, None))
logger.info('Slice dim:%s, Slice obj:%s', str(slice_dim), str(slice_obj))
slice_list.append(slice_obj)
return SliceModule(tuple(slice_list))
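# Example of the slice tuple built above (hypothetical values): for
# slice_dim=1, start=0, end=2, step=1, the module effectively computes
#
#     x[(slice(None, None), slice(0, 2, 1))]   # i.e. x[:, 0:2]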
def select_python(node, speedup):
class SelectModule(torch.nn.Module):
def __init__(self, dim, index):
super(SelectModule, self).__init__()
self.dim = dim
self.index = index
def forward(self, x):
return x.select(self.dim, self.index)
c_node = node.key_node
inputs = list(c_node.inputs())
dim = inputs[1].toIValue()
index = inputs[2].toIValue()
return SelectModule(dim, index)
def size_python(node, speedup):
class SizeModule(torch.nn.Module):
def __init__(self, sizedim):
super(SizeModule, self).__init__()
self.sizedim = sizedim
def forward(self, x):
return torch.as_tensor([x.size(self.sizedim)], dtype=torch.long)
c_node = node.key_node
inputs = list(c_node.inputs())
size_dim = inputs[1].toIValue()
return SizeModule(size_dim)
def toint_python(node, speedup):
class ToIntModule(torch.nn.Module):
def forward(self, x):
return x.to(torch.int)
return ToIntModule()
def view_python(node, speedup):
class ViewModule(torch.nn.Module):
def __init__(self, shape):
super(ViewModule, self).__init__()
self.shape = shape
logger.info('View Module output size: %s', str(self.shape))
def forward(self, *args):
return args[0].view(self.shape)
c_node = node.key_node
inputs = list(c_node.inputs())
shape = translate_list(inputs[1], speedup)
return ViewModule(shape)
def reshape_python(node, speedup):
class ReshapeModule(torch.nn.Module):
def __init__(self, shape):
super(ReshapeModule, self).__init__()
self.shape = shape
logger.info('Reshape Module output size: %s', str(self.shape))
def forward(self, *args):
return args[0].view(self.shape)
c_node = node.key_node
inputs = list(c_node.inputs())
shape = translate_list(inputs[1], speedup)
return ReshapeModule(shape)
def permute_python(node, speedup):
class PermuteModule(torch.nn.Module):
def __init__(self, dimlist):
super(PermuteModule, self).__init__()
self.dimlist = dimlist
def forward(self, x):
return x.permute(self.dimlist)
c_node = node.key_node
inputs = list(c_node.inputs())
dim_list = translate_list(inputs[1], speedup)
return PermuteModule(dim_list)
def getattr_python(node, speedup):
"""
Note: ops starting with prim:: are not taken as key nodes,
so we directly pass the cpp node into this function.
Parameters
----------
node: torch._C.Node
The cpp node of prim::Getattr
speedup: ModelSpeedup
The corresponding speedup object.
"""
class GetModule(torch.nn.Module):
def __init__(self, key):
super(GetModule, self).__init__()
self.key = key
def forward(self, obj):
logger.info('Get attribute: %s', self.key)
return getattr(obj, self.key)
# get the name of the attribute, for example
# prim::GetAttr[name="module_list"](%self.1)
assert node.kind() == 'prim::GetAttr'
pattern = r'\[name="(.*?)"\]'
key_words = re.findall(pattern, str(node))
assert len(key_words) == 1
return GetModule(key_words[0])
def upsample_bilinear2d_python(node, speedup):
class UpsampleModule(torch.nn.Module):
def __init__(self, size_list, scale_list):
super(UpsampleModule, self).__init__()
self.size_list = size_list
self.scale_list = scale_list
def forward(self, *args):
"""
The first element of args is the target tensor to upsample;
the following parameters are unused, because we already
got the size_list and the scale_list by parsing the cpp_nodes.
"""
return torch.nn.functional.upsample_bilinear(args[0],
size=self.size_list, scale_factor=self.scale_list)
c_node = node.key_node
inputs = list(c_node.inputs())
size_list_node = inputs[1].node()
scale_list_node = inputs[3].node()
size_list = None
scale_list = None
if size_list_node.kind() == 'prim::ListConstruct':
size_list = translate_list(inputs[1], speedup)
if scale_list_node.kind() == 'prim::ListConstruct':
scale_list = translate_list(inputs[3], speedup)
return UpsampleModule(size_list, scale_list)
def typeas_python(node, speedup):
"""
Currently only type_as to float is supported.
TODO: support more types in type_as; we need to figure out
how to get the scalar type from torch._C.TensorType.
"""
class TypeasModule(torch.nn.Module):
def __init__(self, dtype=torch.float):
super(TypeasModule, self).__init__()
self.example = torch.zeros(1, dtype=dtype)
def forward(self, x):
return x.type_as(self.example)
return TypeasModule()
def to_python(node, speedup):
# for the time being, only device parameters are supported
class ToModule(torch.nn.Module):
def __init__(self, device):
super(ToModule, self).__init__()
self.device = device
def forward(self, x):
return x.to(self.device)
c_node = node.key_node
inputs = list(c_node.inputs())
device = inputs[3].toIValue()
return ToModule(device)
def cat_python(node, speedup):
class CatModule(torch.nn.Module):
def __init__(self, cat_dim):
super(CatModule, self).__init__()
self.cat_dim = cat_dim
def forward(self, *args):
return torch.cat(args, dim=self.cat_dim)
c_node = node.key_node
inputs = list(c_node.inputs())
dim = inputs[1].toIValue()
return CatModule(dim)
trans_from_jit_to_python = {
'aten::add': add_python,
'aten::add_': add_python,
'aten::mul': mul_python,
'aten::mul_': mul_python,
'aten::relu': relu_python,
'aten::relu_': relu_inplace_python,
'aten::sigmoid': sigmoid_python,
'aten::sigmoid_': sigmoid_python,
# tanh behaves like relu
'aten::tanh': relu_python,
'aten::tanh_': relu_python,
'aten::flatten': flatten_python,
'aten::mean': mean_python,
'aten::dropout': dropout_python,
'aten::slice': slice_python,
'aten::select': select_python,
'aten::size': size_python,
'aten::t': transpose_python,
'aten::transpose': transpose2_python,
'aten::Int': toint_python,
'aten::view': view_python,
'aten::reshape': reshape_python,
'aten::permute': permute_python,
'aten::matmul': matmul_python,
'aten::div': div_python,
'aten::floor_divide': floor_div_python,
'aten::softmax': softmax_python,
'aten::contiguous': contiguous_python,
'aten::gelu': gelu_python,
'aten::cat': cat_python,
'aten::avg_pool2d': avgpool2d_python,
'aten::max_pool2d': avgpool2d_python,
'aten::adaptive_avg_pool2d': adaptive_avgpool_python,
'aten::to': to_python,
'aten::type_as': typeas_python,
'aten::upsample_bilinear2d': upsample_bilinear2d_python,
'aten::exp': exp_python,
'aten::squeeze': squeeze_python,
'prim::TupleUnpack': tupleunpack_python,
'prim::ListUnpack': tupleunpack_python,
'prim::NumToTensor': num2tensor_python,
'prim::GetAttr': getattr_python
}
def jit_to_python_function(node, speedup):
"""
Return a callable object to inference the mask according to the
node.op_type.
Parameters
---------
node: NodeGroup
The target node to inference the mask
speedup: ModelSpeedup
The speedup object of the target model.
Returns
------
func: callable object(nn.Module/function)
Return the translated function that used to inference the mask
, if current op_type is not supported, then we return None.
"""
logger.debug(
'Translate C function %s into its python version', node.op_type)
if node.op_type not in trans_from_jit_to_python:
logger.error(
'%s is not supported! Please report an issue at https://github.com/microsoft/nni. Thanks~', node.op_type)
# return None to skip the mask inference for this node
return None
return trans_from_jit_to_python[node.op_type](node, speedup)
from .utils import *
\ No newline at end of file
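For readers unfamiliar with the dispatch pattern used by trans_from_jit_to_python above, here is a minimal, self-contained sketch of the same idea: map each jit op kind to a factory that returns a callable standing in for the node during mask inference. The two entries and helper names are illustrative only and do not mirror the full table.
import torch
from functools import partial

def _exp_factory():
    return torch.exp

def _squeeze_factory(dim=None):
    return torch.squeeze if dim is None else partial(torch.squeeze, dim=dim)

_translators = {
    'aten::exp': _exp_factory,
    'aten::squeeze': _squeeze_factory,
}

# dispatch: unsupported kinds fall back to None, mirroring jit_to_python_function
kind = 'aten::squeeze'
func = _translators[kind](dim=1) if kind in _translators else None
print(func(torch.zeros(2, 1, 3)).shape)  # torch.Size([2, 3])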
......@@ -51,3 +51,25 @@ class CompressorSchema:
def validate(self, data):
self.compressor_schema.validate(data)
def validate_exclude_sparsity(data):
if not ('exclude' in data or 'sparsity' in data):
raise SchemaError('Either sparsity or exclude must be specified.')
return True
def validate_exclude_quant_types_quant_bits(data):
if not ('exclude' in data or ('quant_types' in data and 'quant_bits' in data)):
raise SchemaError('Either (quant_types and quant_bits) or exclude must be specified.')
return True
class PrunerSchema(CompressorSchema):
def _modify_schema(self, data_schema, model, logger):
data_schema = super()._modify_schema(data_schema, model, logger)
data_schema[0] = And(data_schema[0], lambda d: validate_exclude_sparsity(d))
return data_schema
class QuantizerSchema(CompressorSchema):
def _modify_schema(self, data_schema, model, logger):
data_schema = super()._modify_schema(data_schema, model, logger)
data_schema[0] = And(data_schema[0], lambda d: validate_exclude_quant_types_quant_bits(d))
return data_schema
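To make the new exclude semantics concrete, a hypothetical quantizer config_list is sketched below: per validate_exclude_quant_types_quant_bits, each entry must either set 'exclude' or provide both 'quant_types' and 'quant_bits', otherwise QuantizerSchema raises a SchemaError. The layer name is made up for illustration.
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},
    'op_types': ['Conv2d'],
}, {
    'exclude': True,                 # skip quantization for this match
    'op_names': ['layer4.1.conv2'],  # hypothetical layer name
}]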
......@@ -4,10 +4,10 @@ import os
import logging
import torch
import numpy as np
from .shape_dependency import ChannelDependency, GroupDependency, CatPaddingDependency, InputChannelDependency
from .shape_dependency import ChannelDependency, GroupDependency, InputChannelDependency
from .utils import get_module_by_name
# logging.basicConfig(level = logging.DEBUG)
_logger = logging.getLogger(__name__)
_logger = logging.getLogger('FixMaskConflict')
def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
......@@ -21,7 +21,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
A dict object that stores the masks or the path of the mask file
model : torch.nn.Module
model to fix the mask conflict
dummy_input : torch.Tensor
dummy_input : torch.Tensor/list of tensors/dict of tensors
input example to trace the model
traced : torch._C.torch.jit.TopLevelTracedModule
the traced model of the target model; if this parameter is not None,
......@@ -48,9 +48,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
masks = fix_group_mask.fix_mask()
fix_channel_mask = ChannelMaskConflict(masks, model, dummy_input, traced)
masks = fix_channel_mask.fix_mask()
padding_cat_mask = CatMaskPadding(masks, model, dummy_input, traced)
masks = padding_cat_mask.fix_mask()
return masks, fix_channel_mask.conv_prune_dim
return masks
class MaskFix:
......@@ -78,70 +76,6 @@ class MaskFix:
torch.save(self.masks, path)
class CatMaskPadding(MaskFix):
def __init__(self, masks, model, dummy_input=None, traced=None):
"""
CatMaskPadding finds the layers whose output tensors are passed
to the same cat operation. The cat operation concatenates the
masks of the input tensors as the output mask, so when some
of the input layers of the cat operation are not pruned, we still
need to pass the masks of these non-pruned layers (the masks are
all ones) to the cat operation to ensure the shape of the output
mask is correct.
Parameters
----------
masks : dict
a dict object that stores the masks
model : torch.nn.Module
model to fix the mask conflict
dummy_input : torch.Tensor
input example to trace the model
traced : torch._C.torch.jit.TopLevelTracedModule
the traced model of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced)
def fix_mask(self):
cat_padding_depen = CatPaddingDependency(
self.model, self.dummy_input, self.traced)
name_to_module = {}
for name, module in self.model.named_modules():
name_to_module[name] = module
depen = cat_padding_depen.dependency_sets
for layers in depen:
device = None
count = 0
for layer in layers:
if layer in self.masks:
count += 1
if device is None:
device = self.masks[layer]['weight'].device
if count == 0:
# no layer is pruned
continue
elif count == len(layers):
# all the layers have been pruned
continue
# pad the mask for the non-pruned layers
for layer in layers:
if layer in self.masks:
continue
module = name_to_module[layer]
w_shape = module.weight.data.size()
w_mask = torch.ones(w_shape).to(device)
b_mask = None
if hasattr(module, 'bias') and module.bias is not None:
# module.bias may be None
b_shape = module.bias.data.size()
b_mask = torch.ones(b_shape).to(device)
self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
return self.masks
class GroupMaskConflict(MaskFix):
def __init__(self, masks, model=None, dummy_input=None, traced=None):
"""
......@@ -172,9 +106,11 @@ class GroupMaskConflict(MaskFix):
group_depen = GroupDependency(
self.model, self.dummy_input, self.traced)
depens = group_depen.dependency
min_groups = group_depen.min_groups
_logger.info(depens)
for layername in depens:
group = depens[layername]
group_max = depens[layername]
group_min = min_groups[layername]
if layername not in self.masks:
# this layer not pruned
continue
......@@ -187,29 +123,43 @@ class GroupMaskConflict(MaskFix):
# In fine-grained pruning, skip this layer
_logger.info('Layers %s using fine-grained pruning', layername)
continue
assert shape[0] % group == 0
assert shape[0] % group_max == 0
# Find the minimum number of masked filters among the groups (mini_masked).
# Because the remaining filters must still be evenly divisible into the
# same number of groups, we can only prune mini_masked filters per group.
step = shape[0] / group
step = shape[0] // group_max
group_masked = []
for i in range(group):
for i in range(group_max):
_start = step * i
_end = step * (i + 1)
_tmp_list = list(
filter(lambda x: _start <= x and x < _end, all_zeros))
group_masked.append(_tmp_list)
mini_masked = min([len(x) for x in group_masked])
need_unmask = set()
for gm in group_masked:
for i in range(mini_masked, len(gm)):
# To keep the number of output channels divisible by the group
# count, we unmask (set to ones) the masks of the following filters.
pos = gm[i]
self.masks[layername]['weight'][pos] = torch.ones(
shape[1:])
if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
self.masks[layername]['bias'][pos] = 1
need_unmask.add(pos)
step = shape[0] // group_min
for i in range(group_min):
_start = step * i
_end = step * (i + 1)
_tmp_list = list(
filter(lambda x: _start <= x and x < _end, all_zeros))
if len(_tmp_list) == step:
# if the whole group is removed, then we don't have to unmask
# the filters in this group
for pos in _tmp_list:
if pos in need_unmask:
need_unmask.remove(pos)
for pos in need_unmask:
self.masks[layername]['weight'][pos] = torch.ones(shape[1:])
if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
self.masks[layername]['bias'][pos] = 1
return self.masks
......@@ -234,9 +184,14 @@ class ChannelMaskConflict(MaskFix):
super(ChannelMaskConflict, self).__init__(
masks, model, dummy_input, traced)
self.conv_prune_dim = detect_mask_prune_dim(masks, model)
_logger.info('detected conv prune dim: %s', self.conv_prune_dim)
_logger.info('Detected conv prune dim: %d', self.conv_prune_dim)
def fix_mask(self):
"""
Fix the mask conflict before the mask inference for the layers that
has shape dependencies. This function should be called before the
mask inference of the 'speedup' module.
"""
"""
......@@ -274,7 +229,12 @@ class ChannelMaskConflict(MaskFix):
if (channel_mask.sum() * (mask.numel() / mask.shape[self.conv_prune_dim])).item() != (mask > 0).sum().item():
fine_grained = True
elif type(m).__name__ == 'Linear':
channel_masks.append((mask.abs().sum(0) != 0).int())
if self.conv_prune_dim == 1:
channel_masks.append(
(mask.abs().sum(0) != 0).int())
else:
channel_masks.append(
(mask.abs().sum(1) != 0).int())
elif type(m).__name__ == 'BatchNorm2d':
channel_masks.append(mask.int())
elif type(m).__name__ == 'ConvTranspose2d':
......@@ -293,9 +253,7 @@ class ChannelMaskConflict(MaskFix):
# no mask means not pruned, equivalent to full masks
channel_masks.append(None)
if fine_grained:
_logger.info(
'fine-grained mask detected, skip solving conflict for this set: %s', dset)
continue
_logger.info("Fine-grianed mask detected")
if all(x is None for x in channel_masks):
continue
num_channels_list = [len(x)
......@@ -306,7 +264,8 @@ class ChannelMaskConflict(MaskFix):
for i, dim_mask in enumerate(channel_masks):
if dim_mask is None:
channel_masks[i] = torch.ones(num_channels).int().to(device)
channel_masks[i] = torch.ones(
num_channels).int().to(device)
# merge masks with 'or'
merged_channel_mask = channel_masks[0].clone()
......@@ -329,19 +288,22 @@ class ChannelMaskConflict(MaskFix):
else:
new_mask[:, merged_index, :, :] = 1.
elif type(m).__name__ == 'Linear':
new_mask[:, merged_index] = 1.
if self.conv_prune_dim == 0:
new_mask[merged_index, :] = 1.
elif self.conv_prune_dim == 1:
new_mask[:, merged_index] = 1.
elif type(m).__name__ == 'BatchNorm2d':
new_mask = merged_channel_mask.type_as(orig_mask)
else:
raise RuntimeError(
f'unsupported module type: {type(m).__name__}')
self.masks[name]['weight'] = new_mask
if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
if type(m).__name__ == 'Conv2d':
assert self.conv_prune_dim == 0
self.masks[name]['bias'] = merged_channel_mask.type_as(
self.masks[name]['bias'])
if self.conv_prune_dim == 0:
self.masks[name]['bias'] = merged_channel_mask.type_as(
self.masks[name]['bias'])
return self.masks
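A tiny numeric sketch of the merge step above: channel masks in one dependency set are combined with a logical OR, so a channel is kept if any layer in the set keeps it. The mask values are made up.
import torch

m1 = torch.tensor([1, 0, 0, 1])
m2 = torch.tensor([0, 0, 1, 1])
merged = ((m1 + m2) > 0).int()
print(merged)  # tensor([1, 0, 1, 1])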
......@@ -349,14 +311,12 @@ class ChannelMaskConflict(MaskFix):
def detect_mask_prune_dim(masks, model):
"""
Detect how the masks of convolutional layers are pruned.
Parameters
----------
masks: dict
A dict object that stores the masks.
model: nn.Module
Model object which the mask can be applied on.
Returns
-------
How the masks of convolutional layers are pruned, this depends on pruning algorithms, it should
......
......@@ -3,18 +3,34 @@
import csv
import logging
import numpy as np
__all__ = ['ChannelDependency', 'GroupDependency',
'CatPaddingDependency', 'InputChannelDependency']
__all__ = ['ChannelDependency', 'GroupDependency', 'InputChannelDependency']
CONV_TYPE = 'aten::_convolution'
ADD_TYPES = ['aten::add', 'aten::add_']
MUL_TYPES = ['aten::mul', 'aten::mul_']
CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')
RESHAPE_OPS = [CAT_TYPE, 'aten::view',
'aten::reshape', 'aten::flatten', 'aten::mean']
def lcm_list(L):
lcm = 1
for i in L:
lcm = np.lcm(lcm, i)
return lcm
def gcd_list(L):
gcd = L[0]
for i in L:
gcd = np.gcd(gcd, i)
return gcd
class Dependency:
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
......@@ -38,6 +54,35 @@ class Dependency:
raise NotImplementedError
def reshape_break_channel_dependency(op_node):
"""
The reshape operations (reshape, view, flatten) may break
the channel dependency. We need to check the input parameters of
these reshape operations to determine whether a reshape node breaks
the channel dependency. However, it is complicated to analyze the input
parameters of each reshape function and infer whether it breaks the
channel dependency. So currently we just check whether the input channel
count and the output channel count are the same; if so, the original
reshape function does not intend to change the number of channels, which
means the channel dependency is not broken. Otherwise, the reshape
operation intends to change the number of channels, so it breaks the
channel dependency.
Parameters
----------
op_node: NodePyOP
An op node of the graph.
Returns
-------
bool
Whether this operation breaks the channel dependency.
"""
in_shape = op_node.auxiliary['in_shape']
out_shape = op_node.auxiliary['out_shape']
in_channel = in_shape[1]
out_channel = out_shape[1]
return in_channel != out_channel
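A worked example of this heuristic with made-up shapes: reshaping [1, 64, 8, 8] to [1, 64, 64] keeps dim 1 unchanged, so the dependency survives; flattening to [1, 4096] changes it, so the dependency is treated as broken.
def breaks_dependency(in_shape, out_shape):
    # same check as reshape_break_channel_dependency, on raw shape lists
    return in_shape[1] != out_shape[1]

print(breaks_dependency([1, 64, 8, 8], [1, 64, 64]))  # False -> dependency kept
print(breaks_dependency([1, 64, 8, 8], [1, 4096]))    # True  -> dependency broken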
class ChannelDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
......@@ -80,6 +125,9 @@ class ChannelDependency(Dependency):
# find the first met conv
parent_layers.append(curnode.name)
continue
elif curnode.op_type in RESHAPE_OPS:
if reshape_break_channel_dependency(curnode):
continue
parents = self.graph.find_predecessors(curnode.unique_name)
parents = [self.graph.name_to_node[name] for name in parents]
for parent in parents:
......@@ -176,7 +224,7 @@ class ChannelDependency(Dependency):
d_sets = []
visited = set()
for node in self.graph.nodes_py.nodes_op:
if node.op_type != 'Conv2d' or node in visited:
if (node.op_type != 'Conv2d' and node.op_type != 'Linear') or node in visited:
continue
tmp_set = set()
if node.name not in self.dependency:
......@@ -190,35 +238,6 @@ class ChannelDependency(Dependency):
return d_sets
class InputChannelDependency(ChannelDependency):
"""
Some pruners may prune the input channel of the convolutional
......@@ -295,67 +314,6 @@ class InputChannelDependency(ChannelDependency):
self.dependency[layer] = dependency_set
class CatPaddingDependency(ChannelDependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
super(CatPaddingDependency, self).__init__(
model, dummy_input, traced_model)
def build_dependency(self):
"""
Build the cat padding dependencies.
If the output features of several layers are stitched together
by the cat operation, then these layers have cat padding dependencies.
This is because when inferring the cat mask, we need all the input
masks for the cat operation. At this time we need to know the source
of all input vectors of a cat operation.
"""
for node in self.graph.nodes_py.nodes_op:
parent_layers = []
if node.op_type == CAT_TYPE:
parent_layers = self._get_parent_layers(node)
dependency_set = set(parent_layers)
# merge the dependencies
for parent in parent_layers:
if parent in self.dependency:
dependency_set.update(self.dependency[parent])
# save the dependencies
for _node in dependency_set:
self.dependency[_node] = dependency_set
@property
def dependency_sets(self):
d_sets = []
visited = set()
for nodename in self.dependency:
if nodename in visited:
continue
d_sets.append(self.dependency[nodename])
return d_sets
def export(self, filepath):
"""
Export the dependencies into a file.
In the output file, each line contains a set of layers
whose output features are stitched together by the cat
operation.
output example:
Dependency Set, Layers
set1, Conv1, Conv2
set2, Conv3, Conv4
"""
header = ['Dependency Set', 'Layers']
setid = 0
with open(filepath, 'w') as csvf:
csv_w = csv.writer(csvf, delimiter=',')
csv_w.writerow(header)
for layers in self.dependency_sets:
setid += 1
row = ['Set %d' % setid]
row.extend(list(layers))
csv_w.writerow(row)
class GroupDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
......@@ -372,6 +330,7 @@ class GroupDependency(Dependency):
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
self.min_groups = {}
super(GroupDependency, self).__init__(model, dummy_input, traced_model)
def _get_parent_convs(self, node):
......@@ -451,27 +410,33 @@ class GroupDependency(Dependency):
key: the name of the conv layer; value: the number that the layer's
filter count should be divisible by.
"""
self.groups = {}
for node in self.graph.nodes_py.nodes_op:
if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
group = self._get_conv_groups(node)
if node.name in self.dependency:
if node.name in self.groups:
# a conv layer whose group count is larger than 1 requires its
# number of output channels to be divisible by the group count.
self.dependency[node.name] = max(
self.dependency[node.name], group)
self.groups[node.name].append(group)
else:
self.dependency[node.name] = group
self.groups[node.name] = [group]
if group > 1:
# a conv layer whose group count is larger than 1 also requires the number
# of output channels of its parent conv layers to be divisible by the group count.
parent_convs = self._get_parent_convs(node)
for parent in parent_convs:
if parent in self.dependency:
self.dependency[parent] = max(
self.dependency[parent], group)
if parent in self.groups:
self.groups[parent].append(group)
else:
self.dependency[parent] = group
self.groups[parent] = [group]
for name in self.groups:
self.dependency[name] = lcm_list(self.groups[name])
if min(self.groups[name]) == gcd_list(self.groups[name]):
self.min_groups[name] = min(self.groups[name])
else:
self.min_groups[name] = 1
return self.dependency
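To see how the lcm/gcd bookkeeping plays out, suppose a hypothetical conv layer feeds two grouped convolutions with groups 2 and 4. Its output channel count must then be divisible by lcm(2, 4) = 4, while min_groups stays at 2 because min equals the gcd of the group list. The helpers below are copied from this file so the sketch runs standalone.
import numpy as np

def lcm_list(L):
    lcm = 1
    for i in L:
        lcm = np.lcm(lcm, i)
    return lcm

def gcd_list(L):
    gcd = L[0]
    for i in L:
        gcd = np.gcd(gcd, i)
    return gcd

groups = [2, 4]          # hypothetical group counts collected for one layer
print(lcm_list(groups))  # 4 -> the dependency value
print(min(groups) if min(groups) == gcd_list(groups) else 1)  # 2 -> min_groups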
def export(self, filepath):
......@@ -501,3 +466,110 @@ class GroupDependency(Dependency):
@property
def dependency_sets(self):
return self.dependency
class ReshapeDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
Some models may contain view/reshape functions whose parameters are
hard-coded and cannot be replaced at all. Therefore, these functions impose
constraints on their input shapes. In this class, we find the direct input
conv/linear layers of these reshape functions. If you get a shape conflict
when running forward inference on the sped-up model, try removing these
layers from the pruner config list and run again.
Parameters
----------
model : torch.nn.Module
The model to be analyzed.
data : torch.Tensor
The example input data to trace the network architecture.
traced_model : torch._C.Graph
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
super(ReshapeDependency, self).__init__(
model, dummy_input, traced_model)
def _get_parent_layers(self, node):
"""
Find the nearest parent conv layers for the target node.
Parameters
---------
node : torch._C.Node
target node.
Returns
-------
parent_layers: list
nearest parent conv/linear layers for the target node.
"""
parent_layers = []
queue = []
queue.append(node)
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
# find the first met conv
parent_layers.append(curnode.name)
continue
parents = self.graph.find_predecessors(curnode.unique_name)
parents = [self.graph.name_to_node[name] for name in parents]
for parent in parents:
queue.append(parent)
return parent_layers
def build_dependency(self):
"""
Build the reshape dependencies for the conv layers
in the model.
"""
# unpack the tuple/list manually before analyze the
# channel dependency
self.graph.unpack_manually()
for node in self.graph.nodes_py.nodes_op:
parent_layers = []
# find the nodes that contain aten::view
# or aten::reshape operations
if node.op_type in ['aten::view', 'aten::reshape']:
logger.info('Detect reshape-like functions: %s', node.op_type)
parent_layers = self._get_parent_layers(node)
logger.debug('Parent layers: %s', parent_layers)
self.dependency[node.unique_name] = parent_layers
def export(self, filepath):
"""
export the reshape dependencies as a csv file.
Output example:
Reshape OP, Dependent Layers
model.view.1,layer1.1.conv2,layer1.0.conv2,conv1
model.mean.1,layer1.0.conv1
model.reshape.1,layer1.1.conv1
"""
header = ['Reshape OP', 'Dependent Layers']
with open(filepath, 'w') as csvf:
csv_w = csv.writer(csvf, delimiter=',')
csv_w.writerow(header)
for reshape_op in self.dependency:
row = [reshape_op] + list(self.dependency[reshape_op])
csv_w.writerow(row)
@property
def dependency_sets(self):
"""
Get the list of the dependency set.
Returns
-------
dependency_sets : list
list of the dependency sets. For example,
[set(['conv1', 'conv2']), set(['conv3', 'conv4'])]
"""
d_sets = []
for reshape_node in self.dependency:
d_sets.extend(self.dependency[reshape_node])
d_sets = list(set(d_sets))
return d_sets
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from .shape_dependency import ReshapeDependency
torch_float_dtype = [torch.float, torch.float16, torch.float32, torch.float64, torch.half, torch.double]
torch_integer_dtype = [torch.uint8, torch.int16, torch.short, torch.int32, torch.long, torch.bool]
def get_module_by_name(model, module_name):
"""
......@@ -28,3 +33,57 @@ def get_module_by_name(model, module_name):
return model, leaf_module
else:
return None, None
def rand_like_with_shape(shape, ori_t):
"""
Return a new random tensor like the original
tensor.
"""
assert isinstance(ori_t, torch.Tensor)
device = ori_t.device
dtype = ori_t.dtype
require_grad = ori_t.requires_grad
lower_bound = torch.min(ori_t)
higher_bound = torch.max(ori_t)
if dtype in torch_integer_dtype:
return torch.randint(int(lower_bound), int(higher_bound) + 1, shape, dtype=dtype, device=device)
else:
return torch.rand(shape, dtype=dtype, device=device, requires_grad=require_grad)
def randomize_tensor(tensor, start=1, end=100):
"""
Randomize the target tensor according to the given
range.
"""
assert isinstance(tensor, torch.Tensor)
if tensor.dtype in torch_integer_dtype:
# integer tensor can only be randomized by the torch.randint
# torch.randint(int(start), int(end), tensor.size(), out=tensor.data, dtype=tensor.dtype)
pass
else:
# we can use nn.init.uniform_ to randomize this tensor
# Note: a tensor with an integer dtype cannot be randomized
# with nn.init.uniform_
torch.nn.init.uniform_(tensor.data, start, end)
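Assuming the two helpers above are in scope, a short usage sketch with illustrative values:
import torch

t = torch.zeros(2, 3)
randomize_tensor(t, start=1, end=10)  # in-place uniform fill (float tensors only)
r = rand_like_with_shape((4, 3), t)   # new random tensor on t's device/dtype
print(bool(t.min() >= 1), r.shape)    # True torch.Size([4, 3])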
def not_safe_to_prune(model, dummy_input):
"""
Get the layers that are not safe to prune (they may cause shape conflicts).
For example, if the output tensor of a conv layer is directly followed by
a shape-dependent function (such as reshape/view), then this conv layer
may not be safe to prune, because pruning may change its output shape
and lead to shape problems. This function finds all the layers that are
directly followed by shape-dependent functions (view, reshape, etc.).
If you run inference after the speedup and hit a shape-related error,
please exclude the layers returned by this function and try again.
Parameters
----------
model: torch.nn.Module
The target model to prune.
dummy_input: torch.Tensor/list of torch.Tensor/tuple of Tensor
"""
reshape_dset = ReshapeDependency(model, dummy_input)
return reshape_dset.dependency_sets
\ No newline at end of file
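A hypothetical usage sketch of not_safe_to_prune; torchvision's resnet18 stands in for a real target model, so the output depends on whether its forward uses view/reshape.
import torch
from torchvision.models import resnet18

model = resnet18()
dummy_input = torch.rand(1, 3, 224, 224)
unsafe_layers = not_safe_to_prune(model, dummy_input)
print(unsafe_layers)  # layers directly feeding view/reshape ops, if any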
......@@ -65,6 +65,7 @@ class ExperimentConfig(ConfigBase):
trial_gpu_number: Optional[int] = None # TODO: in openpai cannot be None
max_experiment_duration: Optional[str] = None
max_trial_number: Optional[int] = None
max_trial_duration: Optional[str] = None
nni_manager_ip: Optional[str] = None
use_annotation: bool = False
debug: bool = False
......@@ -153,6 +154,7 @@ _validation_rules = {
'trial_gpu_number': lambda value: value >= 0,
'max_experiment_duration': lambda value: util.parse_time(value) > 0,
'max_trial_number': lambda value: value > 0,
'max_trial_duration': lambda value: util.parse_time(value) > 0,
'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
'tuner_gpu_indices': lambda value: all(i >= 0 for i in value) and len(value) == len(set(value)),
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
......
......@@ -27,6 +27,9 @@ def to_v2(v1) -> ExperimentConfig:
if isinstance(v2.max_experiment_duration, (int, float)):
v2.max_experiment_duration = str(v2.max_experiment_duration) + 's'
_move_field(v1, v2, 'maxTrialNum', 'max_trial_number')
_move_field(v1, v2, 'maxTrialDuration', 'max_trial_duration')
if isinstance(v2.max_trial_duration, (int, float)):
v2.max_trial_duration = str(v2.max_trial_duration) + 's'
_move_field(v1, v2, 'searchSpacePath', 'search_space_file')
assert not v1.pop('multiPhase', None), 'Multi-phase is no longer supported'
_deprecate(v1, v2, 'multiThread')
......
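A hypothetical snippet showing the new field on the Python API; bare numbers in a v1 YAML are converted to a seconds string by the converter above, while v2 accepts duration strings validated via util.parse_time. The experiment setup is illustrative.
from nni.experiment import Experiment

experiment = Experiment('local')             # training service platform
experiment.config.max_trial_number = 20
experiment.config.max_trial_duration = '1h'  # each trial is stopped after one hour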
......@@ -8,7 +8,7 @@ import socket
from subprocess import Popen
import sys
import time
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Any
import colorama
......@@ -43,7 +43,7 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
_check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, str(config.experiment_working_directory))
config.experiment_name, proc.pid, str(config.experiment_working_directory), [])
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
return proc
......@@ -78,7 +78,7 @@ def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int,
_check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, config.experiment_working_directory)
config.experiment_name, proc.pid, config.experiment_working_directory, ['retiarii'])
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
return proc, pipe
......@@ -156,9 +156,10 @@ def _check_rest_server(port: int, retry: int = 3) -> None:
rest.get(port, '/check-status')
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None:
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str,
name: str, pid: int, logDir: str, tag: List[Any]) -> None:
experiments_config = Experiments()
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir, tag=tag)
def get_stopped_experiment_config(exp_id: str, mode: str) -> None:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import warnings
from typing import Dict, Union, Optional, List
from pathlib import Path
from typing import Dict, NoReturn, Union, Optional, List, Type
import pytorch_lightning as pl
import torch.nn as nn
......@@ -18,7 +20,13 @@ __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classificat
class LightningModule(pl.LightningModule):
def set_model(self, model):
"""
Basic wrapper of generated model.
Lightning modules used in NNI should inherit this class.
"""
def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> None:
if isinstance(model, type):
self.model = model()
else:
......@@ -112,13 +120,23 @@ class _SupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: Union[Path, str, bool, None] = None):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
self.criterion = criterion()
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
if export_onnx is None or export_onnx is True:
self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
self.export_onnx.parent.mkdir(exist_ok=True)
elif export_onnx:
self.export_onnx = Path(export_onnx)
else:
self.export_onnx = None
self._already_exported = False
def forward(self, x):
y_hat = self.model(x)
return y_hat
......@@ -135,6 +153,11 @@ class _SupervisedLearningModule(LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
if self.export_onnx is not None and not self._already_exported:
self.to_onnx(self.export_onnx, x, export_params=True)
self._already_exported = True
self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
for name, metric in self.metrics.items():
self.log('val_' + name, metric(y_hat, y), prog_bar=True)
......@@ -152,9 +175,8 @@ class _SupervisedLearningModule(LightningModule):
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())
def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self._get_validation_metrics())
def on_fit_end(self):
nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self):
if len(self.metrics) == 1:
......@@ -175,9 +197,11 @@ class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)
class Classification(Lightning):
......@@ -200,6 +224,8 @@ class Classification(Lightning):
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, the model will be exported to ``model.onnx`` before training starts. Default: true.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
......@@ -211,9 +237,10 @@ class Classification(Lightning):
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
**trainer_kwargs):
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
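A hypothetical instantiation showing the new flag; dataloaders and other arguments are elided, and the remaining keyword arguments are forwarded to the Lightning Trainer.
import torch.nn as nn

cls = Classification(criterion=nn.CrossEntropyLoss,
                     learning_rate=1e-3,
                     export_onnx=False,  # skip writing model.onnx before training
                     max_epochs=2)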
......@@ -223,9 +250,11 @@ class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)
class Regression(Lightning):
......@@ -248,6 +277,8 @@ class Regression(Lightning):
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, the model will be exported to ``model.onnx`` before training starts. Default: true.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
......@@ -259,8 +290,9 @@ class Regression(Lightning):
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
**trainer_kwargs):
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Iterable, List, Optional)
from typing import (Any, Iterable, List, Optional, Tuple)
from .graph import Model, Mutation, ModelStatus
__all__ = ['Sampler', 'Mutator']
__all__ = ['Sampler', 'Mutator', 'InvalidMutation']
Choice = Any
......@@ -77,7 +77,7 @@ class Mutator:
self._cur_choice_idx = None
return copy
def dry_run(self, model: Model) -> List[List[Choice]]:
def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
"""
Dry run mutator on a model to collect choice candidates.
If you invoke this method multiple times on same or different models,
......@@ -115,3 +115,7 @@ class _RecorderSampler(Sampler):
def choice(self, candidates: List[Choice], *args) -> Choice:
self.recorded_candidates.append(candidates)
return candidates[0]
class InvalidMutation(Exception):
pass
from .api import *
from .component import *
from .nn import *
from .hypermodule import *
\ No newline at end of file