Unverified commit aa316742, authored by SparkSnail, committed by GitHub

Merge pull request #233 from microsoft/master

merge master
parents 3fe117f0 24fa4619
@@ -6,3 +6,4 @@ from .pruners import *
 from .weight_rank_filter_pruners import *
 from .activation_rank_filter_pruners import *
 from .quantizers import *
+from .apply_compression import apply_compression_results
@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable
         self.statistics_batch_num = statistics_batch_num
         self.collected_activation = {}
         self.hooks = {}
@@ -48,22 +48,29 @@ class ActivationRankFilterPruner(Pruner):
         """
         Compress the model, register a hook for collecting activations.
         """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return self.bound_model
+        else:
+            self.modules_wrapper = []
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
             self.collected_activation[layer.name] = []

             def _hook(module_, input_, output, name=layer.name):
                 if len(self.collected_activation[name]) < self.statistics_batch_num:
                     self.collected_activation[name].append(self.activation(output.detach().cpu()))
-            layer.module.register_forward_hook(_hook)
+            wrapper.module.register_forward_hook(_hook)
+        self._wrap_model()
         return self.bound_model

     def get_mask(self, base_mask, activations, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion which is calculated from the activation are masked.
@@ -82,14 +89,13 @@ class ActivationRankFilterPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv2d'], "only support Conv2d"
         assert op_type in config.get('op_types')
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
         if hasattr(layer.module, 'bias') and layer.module.bias is not None:
             mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -104,8 +110,7 @@ class ActivationRankFilterPruner(Pruner):
             mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
         finally:
             if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
-                self.mask_dict.update({op_name: mask})
-                self.mask_calculated_ops.add(op_name)
+                if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
         return mask
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
from .compressor import Pruner

logger = logging.getLogger('torch apply compression')

def apply_compression_results(model, masks_file):
    """
    Apply the masks from ```masks_file``` to the model.

    Parameters
    ----------
    model : torch.nn.Module
        The model to be compressed
    masks_file : str
        The path of the mask file
    """
    apply_comp = ApplyCompression(model, masks_file)
    apply_comp.compress()

class ApplyCompression(Pruner):
    """
    This class is not meant to generate masks, but to apply existing masks.
    """

    def __init__(self, model, masks_file):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be masked
        masks_file : str
            The path of the user-provided mask file
        """
        self.bound_model = model
        self.masks = torch.load(masks_file)
        for module_name in self.masks:
            print('module_name: ', module_name)
        config_list = self._build_config()
        super().__init__(model, config_list)

    def _build_config(self):
        op_names = []
        for module_name in self.masks:
            op_names.append(module_name)
        return [{'sparsity': 1, 'op_types': ['default', 'BatchNorm2d'], 'op_names': op_names}]

    def calc_mask(self, layer, config, **kwargs):
        """
        Directly return the corresponding mask.

        Parameters
        ----------
        layer : LayerInfo
            The layer to be pruned
        config : dict
            Pruning configurations for this weight
        kwargs : dict
            Auxiliary information

        Returns
        -------
        dict
            Mask of the layer
        """
        assert layer.name in self.masks
        return self.masks[layer.name]
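A quick usage sketch for the helper above. The file names are made up, and the import path is assumed from the `__init__.py` hunk at the top of this diff; treat both as illustrative rather than guaranteed:

import torch
import torch.nn as nn
from nni.compression.torch import apply_compression_results  # import path assumed from the __init__ diff above

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())  # stand-in for a previously pruned network
model.load_state_dict(torch.load('pruned.pth'))         # weights exported by Pruner.export_model (hypothetical path)
apply_compression_results(model, 'mask.pth')            # multiplies the weights by the masks saved in mask.pth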
@@ -14,8 +14,11 @@ class LayerInfo:
         self.name = name
         self.type = type(module).__name__
-        self._forward = None
+
+
+def _setattr(model, name, module):
+    name_list = name.split(".")
+    for name in name_list[:-1]:
+        model = getattr(model, name)
+    setattr(model, name_list[-1], module)
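A minimal sketch of what `_setattr` does; the toy network below is hypothetical:

import torch.nn as nn

def _setattr(model, name, module):  # same logic as in the hunk above
    name_list = name.split(".")
    for name in name_list[:-1]:
        model = getattr(model, name)
    setattr(model, name_list[-1], module)

net = nn.Sequential(nn.Sequential(nn.Linear(4, 4)))
_setattr(net, "0.0", nn.Linear(4, 2))  # walks the dotted path to net[0], then replaces its child "0"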
class Compressor:
    """
@@ -36,6 +39,9 @@ class Compressor:
         self.bound_model = model
         self.config_list = config_list
         self.modules_to_compress = None
+        self.modules_wrapper = None
+        self.buffers = {}
+        self.is_wrapped = False
    def detect_modules_to_compress(self):
        """
@@ -51,21 +57,60 @@ class Compressor:
             self.modules_to_compress.append((layer, config))
         return self.modules_to_compress

+    def _wrap_model(self):
+        """
+        Wrap all modules that need to be compressed.
+        """
+        for wrapper in reversed(self.get_modules_wrapper()):
+            _setattr(self.bound_model, wrapper.name, wrapper)
+        self.is_wrapped = True
+
+    def _unwrap_model(self):
+        """
+        Unwrap all modules that need to be compressed.
+        """
+        for wrapper in self.get_modules_wrapper():
+            _setattr(self.bound_model, wrapper.name, wrapper.module)
+        self.is_wrapped = False
    def compress(self):
        """
        Compress the model with the algorithm implemented by the subclass.
        The model will be instrumented and the user should never edit it after calling this method.
        `self.modules_to_compress` records all the to-be-compressed layers.
+
+        Returns
+        -------
+        torch.nn.Module
+            model with specified modules compressed.
        """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return self.bound_model
+        else:
+            self.modules_wrapper = []
+
        modules_to_compress = self.detect_modules_to_compress()
        for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
+        self._wrap_model()
        return self.bound_model
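For orientation, a hedged end-to-end sketch of the `compress()` flow above, using the `LevelPruner` that appears later in this diff (the model, sparsity, and import path are illustrative):

import torch.nn as nn
from nni.compression.torch import LevelPruner  # import path assumed

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))
config_list = [{'sparsity': 0.5, 'op_types': ['default']}]
pruner = LevelPruner(model, config_list)
model = pruner.compress()  # the Linear layers are now replaced by PrunerModuleWrapper instances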
+    def register_buffer(self, name, value):
+        """
+        To register buffers used in wrapped module's forward method.
+        """
+        self.buffers[name] = value
    def get_modules_to_compress(self):
        """
-        To obtain all the to-be-compressed layers.
+        To obtain all the to-be-compressed modules.

        Returns
        -------
@@ -75,6 +120,17 @@ class Compressor:
         """
         return self.modules_to_compress
+    def get_modules_wrapper(self):
+        """
+        To obtain all the wrapped modules.
+
+        Returns
+        -------
+        list
+            a list of the wrapped modules
+        """
+        return self.modules_wrapper
    def select_config(self, layer):
        """
        Find the configuration for `layer` by parsing `self.config_list`
@@ -93,13 +149,24 @@ class Compressor:
         ret = None
         for config in self.config_list:
             config = config.copy()
-            config['op_types'] = self._expand_config_op_types(config)
-            if layer.type not in config['op_types']:
+            # expand config if key `default` is in config['op_types']
+            if 'op_types' in config and 'default' in config['op_types']:
+                expanded_op_types = []
+                for op_type in config['op_types']:
+                    if op_type == 'default':
+                        expanded_op_types.extend(default_layers.weighted_modules)
+                    else:
+                        expanded_op_types.append(op_type)
+                config['op_types'] = expanded_op_types
+
+            # check if the condition is satisfied
+            if 'op_types' in config and layer.type not in config['op_types']:
                 continue
-            if config.get('op_names') and layer.name not in config['op_names']:
+            if 'op_names' in config and layer.name not in config['op_names']:
                 continue
             ret = config
-        if ret is None or ret.get('exclude'):
+        if ret is None or 'exclude' in ret:
             return None
         return ret
@@ -119,7 +186,7 @@ class Compressor:
         If users want to update the model every step, they can override this method.
         """

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
         This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
@@ -132,17 +199,66 @@ class Compressor:
         """
         raise NotImplementedError()
-    def _expand_config_op_types(self, config):
-        if config is None:
-            return []
-        expanded_op_types = []
-        for op_type in config.get('op_types', []):
-            if op_type == 'default':
-                expanded_op_types.extend(default_layers.weighted_modules)
-            else:
-                expanded_op_types.append(op_type)
-        return expanded_op_types
+class PrunerModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_name, module_type, config, pruner):
+        """
+        Wrap a module to enable data parallel, forward method customization and buffer registration.
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module user wants to compress
+        config : dict
+            the configurations that users specify for compression
+        module_name : str
+            the name of the module to compress, wrapper module shares same name
+        module_type : str
+            the type of the module to compress
+        pruner : Pruner
+            the pruner used to calculate mask
+        """
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        self.type = module_type
+        # config and pruner
+        self.config = config
+        self.pruner = pruner
+        self.registered_buffers = []
+        # register buffer for mask
+        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
+        else:
+            self.register_buffer("bias_mask", None)
+        self.registered_buffers.append('weight_mask')
+        self.registered_buffers.append('bias_mask')
+        # register user specified buffer
+        for name in self.pruner.buffers:
+            self.register_buffer(name, self.pruner.buffers[name].clone())
+            self.registered_buffers.append(name)
+
+    def get_registered_buffers(self):
+        buffers = {}
+        for name in self.registered_buffers:
+            buffers[name] = getattr(self, name)
+        return buffers
+
+    def forward(self, *inputs):
+        mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.get_registered_buffers())
+        if mask is not None:
+            self.weight_mask.copy_(mask['weight'])
+        # apply mask to weight
+        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
+        # apply mask to bias
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            if mask is not None and 'bias' in mask:
+                self.bias_mask.copy_(mask['bias'])
+            self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
+        return self.module(*inputs)
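The in-place masking that `forward` performs, reduced to bare tensor operations (toy shapes, not the NNI API):

import torch

weight = torch.randn(8, 8)
weight_mask = (torch.rand(8, 8) > 0.5).float()  # 1 keeps a connection, 0 prunes it
weight.data.mul_(weight_mask)                   # zero out pruned weights in place, as the wrapper does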
class Pruner(Compressor):
    """
@@ -158,9 +274,8 @@ class Pruner(Compressor):
     def __init__(self, model, config_list):
         super().__init__(model, config_list)
-        self.mask_dict = {}

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Pruners should overload this method to provide mask for weight tensors.
         The mask must have the same shape and type comparing to the weight.
@@ -176,9 +291,9 @@ class Pruner(Compressor):
         """
         raise NotImplementedError("Pruners must overload calc_mask()")
-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
-        Create a wrapper forward function to replace the original one.
+        Create a wrapper module to replace the original one.

         Parameters
         ----------
@@ -187,30 +302,14 @@ class Pruner(Compressor):
         config : dict
             the configuration for generating the mask
         """
-        assert layer._forward is None, 'Each model can only be compressed once'
-        if not _check_weight(layer.module):
-            _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            return
-        layer._forward = layer.module.forward
-
-        def new_forward(*inputs):
-            mask = self.calc_mask(layer, config)
-            # apply mask to weight
-            old_weight = layer.module.weight.data
-            mask_weight = mask['weight']
-            layer.module.weight.data = old_weight.mul(mask_weight)
-            # apply mask to bias
-            if mask.__contains__('bias') and hasattr(layer.module, 'bias') and layer.module.bias is not None:
-                old_bias = layer.module.bias.data
-                mask_bias = mask['bias']
-                layer.module.bias.data = old_bias.mul(mask_bias)
-            # calculate forward
-            ret = layer._forward(*inputs)
-            return ret
-
-        layer.module.forward = new_forward
+        _logger.info("compressing module %s.", layer.name)
+        wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
+        # move newly registered buffers to the same device of weight
+        wrapper.to(layer.module.weight.device)
+        return wrapper

-    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
+    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
""" """
Export pruned model weights, masks and onnx model(optional) Export pruned model weights, masks and onnx model(optional)
...@@ -224,35 +323,144 @@ class Pruner(Compressor): ...@@ -224,35 +323,144 @@ class Pruner(Compressor):
(optional) path to save onnx model (optional) path to save onnx model
input_shape : list or tuple input_shape : list or tuple
input shape to onnx model input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
""" """
if self.detect_modules_to_compress() and not self.mask_dict: # if self.detect_modules_to_compress() and not self.mask_dict:
_logger.warning('You may not use self.mask_dict in base Pruner class to record masks') # _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
assert model_path is not None, 'model_path must be specified' assert model_path is not None, 'model_path must be specified'
for name, m in self.bound_model.named_modules(): mask_dict = {}
if name == "": self._unwrap_model() # used for generating correct state_dict name without wrapper state
continue
masks = self.mask_dict.get(name) for wrapper in self.get_modules_wrapper():
if masks is not None: weight_mask = wrapper.weight_mask
mask_sum = masks['weight'].sum().item() bias_mask = wrapper.bias_mask
mask_num = masks['weight'].numel() if weight_mask is not None:
_logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num) mask_sum = weight_mask.sum().item()
m.weight.data = m.weight.data.mul(masks['weight']) mask_num = weight_mask.numel()
if masks.__contains__('bias') and hasattr(m, 'bias') and m.bias is not None: _logger.info('Layer: %s Sparsity: %.2f', wrapper.name, 1 - mask_sum / mask_num)
m.bias.data = m.bias.data.mul(masks['bias']) wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
else: if bias_mask is not None:
_logger.info('Layer: %s NOT compressed', name) wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
# save mask to dict
mask_dict[wrapper.name] = {"weight": weight_mask, "bias": bias_mask}
torch.save(self.bound_model.state_dict(), model_path) torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path) _logger.info('Model state_dict saved to %s', model_path)
if mask_path is not None: if mask_path is not None:
torch.save(self.mask_dict, mask_path) torch.save(mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path) _logger.info('Mask dict saved to %s', mask_path)
if onnx_path is not None: if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model' assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed # input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape) input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data, onnx_path) torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path) _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
self._wrap_model()
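A hedged call sketch for the new `export_model` signature, continuing the `compress()` sketch earlier; all paths and the input shape are made up:

import torch

pruner.export_model(model_path='pruned.pth', mask_path='mask.pth',
                    onnx_path='pruned.onnx', input_shape=[1, 3, 32, 32],
                    device=torch.device('cpu'))  # the dummy ONNX input is placed on this device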
+    def load_model_state_dict(self, model_state):
+        """
+        Load the state dict saved from the unwrapped model.
+
+        Parameters
+        ----------
+        model_state : dict
+            state dict saved from the unwrapped model
+        """
+        if self.is_wrapped:
+            self._unwrap_model()
+            self.bound_model.load_state_dict(model_state)
+            self._wrap_model()
+        else:
+            self.bound_model.load_state_dict(model_state)
+class QuantizerModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_name, module_type, config, quantizer):
+        """
+        Wrap a module to enable data parallel, forward method customization and buffer registration.
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module user wants to compress
+        config : dict
+            the configurations that users specify for compression
+        module_name : str
+            the name of the module to compress, wrapper module shares same name
+        module_type : str
+            the type of the module to compress
+        quantizer : Quantizer
+            the quantizer used to quantize the module
+        """
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        self.type = module_type
+        # config and quantizer
+        self.config = config
+        self.quantizer = quantizer
+        self.registered_buffers = []
+
+        # register buffer and parameter
+        # old_weight stores the original weight and weight stores the quantized weight.
+        # weight is a buffer instead of a parameter because in pytorch a parameter is a leaf node;
+        # if weight were a leaf, old_weight could not be updated.
+        if 'weight' in config['quant_types']:
+            if not _check_weight(self.module):
+                _logger.warning('Module %s does not have parameter "weight"', self.name)
+            else:
+                self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight))
+                delattr(self.module, 'weight')
+                self.module.register_buffer('weight', self.module.old_weight)
+
+        # register user specified buffer
+        for name in self.quantizer.buffers:
+            self.register_buffer(name, self.quantizer.buffers[name].clone())
+            self.registered_buffers.append(name)
+
+    def get_registered_buffers(self):
+        buffers = {}
+        for name in self.registered_buffers:
+            buffers[name] = getattr(self, name)
+        return buffers
+
+    def forward(self, *inputs):
+        if 'input' in self.config['quant_types']:
+            inputs = self.quantizer.quant_grad.apply(
+                inputs,
+                QuantType.QUANT_INPUT,
+                self.quantizer.quantize_input,
+                self.config,
+                LayerInfo(self.name, self.module),
+                **self.get_registered_buffers())
+
+        if 'weight' in self.config['quant_types'] and _check_weight(self.module):
+            new_weight = self.quantizer.quant_grad.apply(
+                self.module.old_weight,
+                QuantType.QUANT_WEIGHT,
+                self.quantizer.quantize_weight,
+                self.config,
+                LayerInfo(self.name, self.module),
+                **self.get_registered_buffers())
+            self.module.weight = new_weight
+            result = self.module(*inputs)
+        else:
+            result = self.module(*inputs)
+
+        if 'output' in self.config['quant_types']:
+            result = self.quantizer.quant_grad.apply(
+                result,
+                QuantType.QUANT_OUTPUT,
+                self.quantizer.quantize_output,
+                self.config,
+                LayerInfo(self.name, self.module),
+                **self.get_registered_buffers())
+        return result
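The `old_weight` re-registration described in the comment above, shown standalone on a plain `nn.Linear` (illustrative only):

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
m.register_parameter('old_weight', nn.Parameter(m.weight))
delattr(m, 'weight')
m.register_buffer('weight', m.old_weight)  # forward reads 'weight'; gradients flow to 'old_weight'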
class Quantizer(Compressor):
    """
@@ -303,7 +511,7 @@ class Quantizer(Compressor):
         raise NotImplementedError('Quantizer must overload quantize_input()')

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
         Create a wrapper forward function to replace the original one.

         Parameters
@@ -313,7 +521,6 @@ class Quantizer(Compressor):
         config : dict
             the configuration for quantization
         """
-        assert layer._forward is None, 'Each model can only be compressed once'
         assert 'quant_types' in config, 'must provide quant_types in config'
         assert isinstance(config['quant_types'], list), 'quant_types must be list type'
         assert 'quant_bits' in config, 'must provide quant_bits in config'
@@ -323,35 +530,7 @@ class Quantizer(Compressor):
         for quant_type in config['quant_types']:
             assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

-        if 'weight' in config['quant_types']:
-            if not _check_weight(layer.module):
-                _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            else:
-                # old_weight is used to store origin weight and weight is used to store quantized weight
-                # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
-                # if weight is leaf , then old_weight can not be updated.
-                layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
-                delattr(layer.module, 'weight')
-                layer.module.register_buffer('weight', layer.module.old_weight)
-
-        layer._forward = layer.module.forward
-
-        def new_forward(*inputs):
-            if 'input' in config['quant_types']:
-                inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
-            if 'weight' in config['quant_types'] and _check_weight(layer.module):
-                new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
-                layer.module.weight = new_weight
-                result = layer._forward(*inputs)
-            else:
-                result = layer._forward(*inputs)
-            if 'output' in config['quant_types']:
-                result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
-            return result
-
-        layer.module.forward = new_forward
+        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
class QuantType:
    """
@@ -387,19 +566,18 @@ class QuantGrad(torch.autograd.Function):
         return grad_output

     @staticmethod
-    def forward(ctx, tensor, quant_type, quant_func, config, layer):
+    def forward(ctx, tensor, quant_type, quant_func, config, layer, **kwargs):
         ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
-        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name, **kwargs)

     @classmethod
     def backward(cls, ctx, grad_output):
         tensor, quant_type = ctx.saved_variables
         output = cls.quant_backward(tensor, grad_output, quant_type)
-        return output, None, None, None, None
+        return output, None, None, None, None, None

def _check_weight(module):
    try:
        return isinstance(module.weight.data, torch.Tensor)
    except AttributeError:
        return False
\ No newline at end of file
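With the updated `QuantGrad.forward` signature, a subclass's `quantize_weight` now also receives the registered buffers through `**kwargs`. A minimal, hypothetical quantizer to show the hook's shape; the symmetric 8-bit rounding scheme here is illustrative, not an NNI built-in:

class Naive8bitQuantizer(Quantizer):
    def quantize_weight(self, weight, config, **kwargs):
        # kwargs carries op/op_type/op_name plus any buffers registered via register_buffer
        scale = weight.abs().max() / 127
        return torch.round(weight / scale) * scale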
@@ -27,9 +27,9 @@ class LevelPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer

         Parameters
@@ -45,8 +45,9 @@ class LevelPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
-        if op_name not in self.mask_calculated_ops:
+        if_calculated = kwargs["if_calculated"]
+        if not if_calculated:
             w_abs = weight.abs()
             k = int(weight.numel() * config['sparsity'])
             if k == 0:
@@ -54,12 +55,10 @@ class LevelPruner(Pruner):
             threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
             mask_weight = torch.gt(w_abs, threshold).type_as(weight)
             mask = {'weight': mask_weight}
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
+            if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
+            return mask
         else:
-            assert op_name in self.mask_dict, "op_name not in the mask_dict"
-            mask = self.mask_dict[op_name]
-        return mask
+            return None
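The magnitude-thresholding core of `LevelPruner.calc_mask`, as standalone tensor operations (toy tensor, sparsity value illustrative):

import torch

weight = torch.randn(100)
sparsity = 0.5
k = int(weight.numel() * sparsity)
threshold = torch.topk(weight.abs().view(-1), k, largest=False)[0].max()
mask = torch.gt(weight.abs(), threshold).float()  # keeps roughly the (1 - sparsity) largest magnitudes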
class AGP_Pruner(Pruner):
@@ -84,17 +83,20 @@ class AGP_Pruner(Pruner):
         super().__init__(model, config_list)
         self.now_epoch = 0
-        self.if_init_list = {}
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
-        Calculate the mask of given layer
+        Calculate the mask of given layer.
+        Weights with the smallest absolute values are masked.

         Parameters
         ----------
         layer : LayerInfo
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs : dict
+            buffers registered in __init__ function

         Returns
         -------
         dict
@@ -102,24 +104,26 @@ class AGP_Pruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         start_epoch = config.get('start_epoch', 0)
         freq = config.get('frequency', 1)
-        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
-                and (self.now_epoch - start_epoch) % freq == 0:
-            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
-            target_sparsity = self.compute_target_sparsity(config)
-            k = int(weight.numel() * target_sparsity)
-            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
-                return mask
-            # if we want to generate new mask, we should update weight first
-            w_abs = weight.abs() * mask['weight']
-            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-            new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
-            self.mask_dict.update({op_name: new_mask})
-            self.if_init_list.update({op_name: False})
-        else:
-            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
+        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
+            return None
+
+        mask = {'weight': kwargs['weight_mask'] if 'weight_mask' in kwargs else torch.ones(weight.shape).type_as(weight)}
+        target_sparsity = self.compute_target_sparsity(config)
+        k = int(weight.numel() * target_sparsity)
+        if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
+            return mask
+        # if we want to generate a new mask, we should update the weight first
+        w_abs = weight.abs() * mask['weight']
+        threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
+        new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
+        if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
        return new_mask
    def compute_target_sparsity(self, config):
@@ -165,9 +169,8 @@ class AGP_Pruner(Pruner):
         if epoch > 0:
             self.now_epoch = epoch
-            for k in self.if_init_list.keys():
-                self.if_init_list[k] = True
+            for wrapper in self.get_modules_wrapper():
+                wrapper.if_calculated.copy_(torch.tensor(0))  # pylint: disable=not-callable
class SlimPruner(Pruner):
    """
@@ -187,7 +190,6 @@ class SlimPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
         weight_list = []
         if len(config_list) > 1:
             logger.warning('Slim pruner only supports 1 configuration')
@@ -198,8 +200,9 @@ class SlimPruner(Pruner):
         all_bn_weights = torch.cat(weight_list)
         k = int(all_bn_weights.shape[0] * config['sparsity'])
         self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable
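What the global threshold in `SlimPruner.__init__` computes, in isolation (the BN scale tensors below are random stand-ins):

import torch

bn_scales = [torch.rand(16), torch.rand(32)]  # gamma vectors of two BatchNorm2d layers
all_bn_weights = torch.cat(bn_scales)
k = int(all_bn_weights.shape[0] * 0.5)        # sparsity = 0.5
global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()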
-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Scale factors with the smallest absolute value in the BN layer are masked.
@@ -209,6 +212,8 @@ class SlimPruner(Pruner):
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs : dict
+            buffers registered in __init__ function

         Returns
         -------
         dict
@@ -216,27 +221,21 @@ class SlimPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
+        if_calculated = kwargs["if_calculated"]
         assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if if_calculated:
+            return None
         base_mask = torch.ones(weight.size()).type_as(weight).detach()
         mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
-        try:
-            filters = weight.size(0)
-            num_prune = int(filters * config.get('sparsity'))
-            if filters < 2 or num_prune < 1:
-                return mask
+        filters = weight.size(0)
+        num_prune = int(filters * config.get('sparsity'))
+        if filters >= 2 and num_prune >= 1:
             w_abs = weight.abs()
             mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
             mask_bias = mask_weight.clone()
             mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
-        finally:
-            self.mask_dict.update({layer.name: mask})
-            self.mask_calculated_ops.add(layer.name)
+        if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
         return mask
class LotteryTicketPruner(Pruner):
@@ -294,38 +293,23 @@ class LotteryTicketPruner(Pruner):
         prune_iterations = config['prune_iterations']
         return prune_iterations

-    def _print_masks(self, print_mask=False):
-        torch.set_printoptions(threshold=1000)
-        for op_name in self.mask_dict.keys():
-            mask = self.mask_dict[op_name]
-            print('op name: ', op_name)
-            if print_mask:
-                print('mask: ', mask)
-            # calculate current sparsity
-            mask_num = mask['weight'].sum().item()
-            mask_size = mask['weight'].numel()
-            print('sparsity: ', 1 - mask_num / mask_size)
-        torch.set_printoptions(profile='default')

     def _calc_sparsity(self, sparsity):
         keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
         curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
         return max(1 - curr_keep_ratio, 0)
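A quick numeric check of `_calc_sparsity`: pruning is applied geometrically, so with sparsity 0.75 and two prune iterations each round keeps sqrt(0.25) = 0.5 of the remaining weights:

sparsity, prune_iterations = 0.75, 2
keep_ratio_once = (1 - sparsity) ** (1 / prune_iterations)
for curr_prune_iteration in range(prune_iterations + 1):
    print(curr_prune_iteration, max(1 - keep_ratio_once ** curr_prune_iteration, 0))
# prints: 0 0.0, 1 0.5, 2 0.75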
-    def _calc_mask(self, weight, sparsity, op_name):
+    def _calc_mask(self, weight, sparsity, curr_w_mask):
         if self.curr_prune_iteration == 0:
             mask = torch.ones(weight.shape).type_as(weight)
         else:
             curr_sparsity = self._calc_sparsity(sparsity)
-            assert self.mask_dict.get(op_name) is not None
-            curr_mask = self.mask_dict.get(op_name)
-            w_abs = weight.abs() * curr_mask['weight']
+            w_abs = weight.abs() * curr_w_mask
             k = int(w_abs.numel() * curr_sparsity)
             threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
             mask = torch.gt(w_abs, threshold).type_as(weight)
         return {'weight': mask}
-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Generate mask for the given ``weight``.
@@ -335,15 +319,17 @@ class LotteryTicketPruner(Pruner):
             The layer to be pruned
         config : dict
             Pruning configurations for this weight
+        kwargs : dict
+            Auxiliary information

         Returns
         -------
         tensor
-            The mask for this weight
+            Always ```None``` here, because this pruner calculates and assigns
+            masks in ```prune_iteration_start```, so there is nothing to do in
+            this function.
         """
-        assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
-        mask = self.mask_dict[layer.name]
-        return mask
+        return None
    def get_prune_iterations(self):
        """
@@ -368,16 +354,26 @@ class LotteryTicketPruner(Pruner):
         self.curr_prune_iteration += 1
         assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'

+        modules_wrapper = self.get_modules_wrapper()
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
+            module_wrapper = None
+            for wrapper in modules_wrapper:
+                if wrapper.name == layer.name:
+                    module_wrapper = wrapper
+                    break
+            assert module_wrapper is not None

             sparsity = config.get('sparsity')
-            mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
-            self.mask_dict.update({layer.name: mask})
-        self._print_masks()
+            mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
+            # TODO: directly use weight_mask is not good
+            module_wrapper.weight_mask.copy_(mask['weight'])
+            # there is no mask for bias

         # reinit weights back to original after new masks are generated
         if self.reset_weights:
-            self._model.load_state_dict(self._model_state)
+            # should use this member function to reset model weights
+            self.load_model_state_dict(self._model_state)
             self._optimizer.load_state_dict(self._optimizer_state)
             if self._lr_scheduler is not None:
                 self._lr_scheduler.load_state_dict(self._scheduler_state)
@@ -27,12 +27,12 @@ class WeightRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()  # operations whose mask has been calculated
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion of the kernel weights are masked.
@@ -49,14 +49,13 @@ class WeightRankFilterPruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         op_type = layer.type
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
         assert op_type in config.get('op_types')
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
         if hasattr(layer.module, 'bias') and layer.module.bias is not None:
             mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -70,8 +69,7 @@ class WeightRankFilterPruner(Pruner):
                 return mask
             mask = self.get_mask(mask, weight, num_prune)
         finally:
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
+            if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
         return mask
@@ -259,4 +257,5 @@ class FPGMPruner(WeightRankFilterPruner):
         return x.sum()

     def update_epoch(self, epoch):
-        self.mask_calculated_ops = set()
+        for wrapper in self.get_modules_wrapper():
+            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0))  # pylint: disable=not-callable
@@ -380,6 +380,8 @@ class Hyperband(MsgDispatcherBase):
         ValueError
             Data type not supported
         """
+        if 'value' in data:
+            data['value'] = json_tricks.loads(data['value'])
         if data['type'] == MetricType.REQUEST_PARAMETER:
             assert multi_phase_enabled()
             assert data['trial_job_id'] is not None
...
@@ -31,11 +31,11 @@ def test():
     # [1,1,1,1,1,1,1,1,1,1],
     # [1,1,1,1,1,1,1,1,1,1]]
-    assessor = MedianstopAssessor(FLAGS.start_step, FLAGS.optimize_mode)
+    assessor = MedianstopAssessor(FLAGS.optimize_mode, FLAGS.start_step)
-    for i in range(4):
+    for i in range(len(lcs)):
         #lc = []
         to_complete = True
-        for k in range(10):
+        for k in range(len(lcs[0])):
             #d = random.randint(i*100+0, i*100+100)
             #lc.append(d)
             ret = assessor.assess_trial(i, lcs[i][:k+1])
...
@@ -113,6 +113,8 @@ class MsgDispatcher(MsgDispatcherBase):
         """Import additional data for tuning
         data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
         """
+        for entry in data:
+            entry['value'] = json_tricks.loads(entry['value'])
         self.tuner.import_data(data)
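What the decode steps in this file undo, as a round trip; the value is illustrative, and `json_tricks` is the library the dispatcher already uses:

import json_tricks

encoded = json_tricks.dumps(0.93)   # the trial side serializes the metric value
value = json_tricks.loads(encoded)  # the dispatcher side restores it before use
assert value == 0.93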
    def handle_add_customized_trial(self, data):
@@ -127,6 +129,9 @@ class MsgDispatcher(MsgDispatcherBase):
         - 'value': metric value reported by nni.report_final_result()
         - 'type': report type, support {'FINAL', 'PERIODICAL'}
         """
+        # metrics value is dumped as json string in trial, so we need to decode it here
+        if 'value' in data:
+            data['value'] = json_tricks.loads(data['value'])
         if data['type'] == MetricType.FINAL:
             self._handle_final_metric_data(data)
         elif data['type'] == MetricType.PERIODICAL:
...
@@ -13,7 +13,12 @@ logger = logging.getLogger(__name__)
 class BaseMutator(nn.Module):
     """
     A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing
-    callbacks that are called in ``forward`` in Mutables.
+    callbacks that are called in ``forward`` in mutables.
+
+    Parameters
+    ----------
+    model : nn.Module
+        PyTorch model to apply mutator on.
     """

     def __init__(self, model):
@@ -52,9 +57,23 @@ class BaseMutator(nn.Module):
     @property
     def mutables(self):
+        """
+        A generator of all modules inheriting :class:`~nni.nas.pytorch.mutables.Mutable`.
+        Modules are yielded in the order that they are defined in ``__init__``.
+        For mutables whose keys appear multiple times, only the first one will appear.
+        """
         return self._structured_mutables

+    @property
+    def undedup_mutables(self):
+        return self._structured_mutables.traverse(deduplicate=False)
+
     def forward(self, *inputs):
+        """
+        Warnings
+        --------
+        Don't call forward of a mutator.
+        """
         raise RuntimeError("Forward is undefined for mutators.")

     def __setattr__(self, name, value):
@@ -70,6 +89,7 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
         mutable_scope : MutableScope
+            The mutable scope that is entered.
         """
         pass
@@ -80,6 +100,7 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
         mutable_scope : MutableScope
+            The mutable scope that is exited.
         """
         pass
@@ -90,12 +111,14 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
         mutable : LayerChoice
+            Module whose forward is called.
         inputs : list of torch.Tensor
+            The arguments of its forward function.

         Returns
         -------
         tuple of torch.Tensor and torch.Tensor
-            output tensor and mask
+            Output tensor and mask.
         """
         raise NotImplementedError
@@ -106,12 +129,14 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
         mutable : InputChoice
+            Mutable that is called.
         tensor_list : list of torch.Tensor
+            The arguments the mutable is called with.

         Returns
         -------
         tuple of torch.Tensor and torch.Tensor
-            output tensor and mask
+            Output tensor and mask.
         """
         raise NotImplementedError
@@ -123,5 +148,6 @@ class BaseMutator(nn.Module):
         Returns
         -------
         dict
+            Mappings from mutable keys to decisions.
         """
         raise NotImplementedError
@@ -8,16 +8,33 @@ class BaseTrainer(ABC):
     @abstractmethod
     def train(self):
+        """
+        Override the method to train.
+        """
         raise NotImplementedError

     @abstractmethod
     def validate(self):
+        """
+        Override the method to validate.
+        """
         raise NotImplementedError

     @abstractmethod
     def export(self, file):
+        """
+        Override the method to export to file.
+
+        Parameters
+        ----------
+        file : str
+            File path to export to.
+        """
         raise NotImplementedError

     @abstractmethod
     def checkpoint(self):
+        """
+        Override to dump a checkpoint.
+        """
         raise NotImplementedError
@@ -11,6 +11,9 @@ _logger = logging.getLogger(__name__)
 class Callback:
+    """
+    Callback provides an easy way to react to events like begin/end of epochs.
+    """

     def __init__(self):
         self.model = None
@@ -18,14 +21,42 @@ class Callback:
         self.trainer = None

     def build(self, model, mutator, trainer):
+        """
+        Callback needs to be built with model, mutator and trainer, to get updates from them.
+
+        Parameters
+        ----------
+        model : nn.Module
+            Model to be trained.
+        mutator : nn.Module
+            Mutator that mutates the model.
+        trainer : BaseTrainer
+            Trainer that is to call the callback.
+        """
         self.model = model
         self.mutator = mutator
         self.trainer = trainer

     def on_epoch_begin(self, epoch):
+        """
+        Implement this to do something at the beginning of an epoch.
+
+        Parameters
+        ----------
+        epoch : int
+            Epoch number, starting from 0.
+        """
         pass

     def on_epoch_end(self, epoch):
+        """
+        Implement this to do something at the end of an epoch.
+
+        Parameters
+        ----------
+        epoch : int
+            Epoch number, starting from 0.
+        """
         pass

     def on_batch_begin(self, epoch):
@@ -36,6 +67,14 @@ class Callback:
 class LRSchedulerCallback(Callback):
+    """
+    Calls the scheduler at the end of every epoch.
+
+    Parameters
+    ----------
+    scheduler : LRScheduler
+        Scheduler to be called.
+    """

     def __init__(self, scheduler, mode="epoch"):
         super().__init__()
         assert mode == "epoch"
@@ -43,28 +82,54 @@ class LRSchedulerCallback(Callback):
         self.mode = mode

     def on_epoch_end(self, epoch):
+        """
+        Call ``self.scheduler.step()`` on epoch end.
+        """
         self.scheduler.step()

 class ArchitectureCheckpoint(Callback):
+    """
+    Calls ``trainer.export()`` at the end of every epoch.
+
+    Parameters
+    ----------
+    checkpoint_dir : str
+        Location to save checkpoints.
+    """

     def __init__(self, checkpoint_dir):
         super().__init__()
         self.checkpoint_dir = checkpoint_dir
         os.makedirs(self.checkpoint_dir, exist_ok=True)

     def on_epoch_end(self, epoch):
+        """
+        Dump to ``/checkpoint_dir/epoch_{number}.json`` on epoch end.
+        """
         dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
         _logger.info("Saving architecture to %s", dest_path)
         self.trainer.export(dest_path)
 class ModelCheckpoint(Callback):
+    """
+    Saves a model checkpoint at the end of every epoch.
+
+    Parameters
+    ----------
+    checkpoint_dir : str
+        Location to save checkpoints.
+    """

     def __init__(self, checkpoint_dir):
         super().__init__()
         self.checkpoint_dir = checkpoint_dir
         os.makedirs(self.checkpoint_dir, exist_ok=True)

     def on_epoch_end(self, epoch):
+        """
+        Dump to ``/checkpoint_dir/epoch_{number}.pth.tar`` on every epoch end.
+        ``DataParallel`` objects will have their inner modules exported.
+        """
         if isinstance(self.model, nn.DataParallel):
             state_dict = self.model.module.state_dict()
         else:
...
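A hedged wiring sketch for the callbacks above; the model is a stand-in and the trainer hookup is schematic (any NNI NAS trainer that accepts a `callbacks` list would look similar):

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 4)  # stand-in for the searched network
optimizer = optim.SGD(model.parameters(), lr=0.05)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
callbacks = [LRSchedulerCallback(scheduler),
             ArchitectureCheckpoint("./checkpoints"),
             ModelCheckpoint("./checkpoints")]
# a trainer built with these callbacks calls cb.on_epoch_end(epoch) after each epoch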
@@ -127,18 +127,15 @@ class RegularizedMutatorParallel(DistributedDataParallel):
 class DartsDiscreteMutator(Mutator):
     """
     A mutator that applies the final sampling result of a parent mutator on another model to train.
+
+    Parameters
+    ----------
+    model : nn.Module
+        The model to apply the mutator.
+    parent_mutator : Mutator
+        The mutator that provides the ``sample_final`` method, which will be called to get the architecture.
     """

     def __init__(self, model, parent_mutator):
-        """
-        Initialization.
-
-        Parameters
-        ----------
-        model : nn.Module
-            The model to apply the mutator.
-        parent_mutator : Mutator
-            The mutator that provides ``sample_final`` method, that will be called to get the architecture.
-        """
         super().__init__(model)
         self.__dict__["parent_mutator"] = parent_mutator  # avoid parameters to be included
...
...@@ -32,73 +32,73 @@ class InteractiveKLLoss(nn.Module): ...@@ -32,73 +32,73 @@ class InteractiveKLLoss(nn.Module):
class CdartsTrainer(object): class CdartsTrainer(object):
"""
CDARTS trainer.
Parameters
----------
model_small : nn.Module
PyTorch model to be trained. This is the search network of CDARTS.
model_large : nn.Module
PyTorch model to be trained. This is the evaluation network of CDARTS.
criterion : callable
Receives logits and ground truth label, return a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
loaders : list of torch.utils.data.DataLoader
List of train data and valid data loaders, for training weights and architecture weights respectively.
samplers : list of torch.utils.data.Sampler
List of train data and valid data samplers. This can be PyTorch standard samplers if not distributed.
In distributed mode, sampler needs to have ``set_epoch`` method. Refer to data utils in CDARTS example for details.
logger : logging.Logger
The logger for logging. Will use nni logger by default (if logger is ``None``).
regular_coeff : float
The coefficient of regular loss.
regular_ratio : float
The ratio of regular loss.
warmup_epochs : int
The epochs to warmup the search network
fix_head : bool
``True`` if fixing the paramters of auxiliary heads, else unfix the paramters of auxiliary heads.
epochs : int
Number of epochs planned for training.
steps_per_epoch : int
Steps of one epoch.
loss_alpha : float
The loss coefficient.
loss_T : float
The loss coefficient.
distributed : bool
``True`` if using distributed training, else non-distributed training.
log_frequency : int
Step count per logging.
grad_clip : float
Gradient clipping for weights.
interactive_type : string
``kl`` or ``smoothl1``.
output_path : string
Log storage path.
w_lr : float
Learning rate of the search network parameters.
w_momentum : float
Momentum of the search and the evaluation network.
w_weight_decay : float
The weight decay of the search and the evaluation network parameters.
alpha_lr : float
Learning rate of the architecture parameters.
alpha_weight_decay : float
The weight decay of the architecture parameters.
nasnet_lr : float
Learning rate of the evaluation network parameters.
local_rank : int
Rank of the current process in distributed training.
share_module : bool
``True`` if sharing the stem and auxiliary heads, else not sharing these modules.
"""
def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None,
regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True,
epochs=32, steps_per_epoch=None, loss_alpha=2, loss_T=2, distributed=True,
log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs',
w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4,
nasnet_lr=0.2, local_rank=0, share_module=True):
"""
Initialize a CdartsTrainer.
Parameters
----------
model_small : nn.Module
PyTorch model to be trained. This is the search network of CDARTS.
model_large : nn.Module
PyTorch model to be trained. This is the evaluation network of CDARTS.
criterion : callable
Receives logits and ground truth label, return a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
loaders : list of torch.utils.data.DataLoader
List of train data and valid data loaders, for training weights and architecture weights respectively.
samplers : list of torch.utils.data.Sampler
List of train data and valid data samplers. This can be PyTorch standard samplers if not distributed.
In distributed mode, sampler needs to have ``set_epoch`` method. Refer to data utils in CDARTS example for details.
logger : logging.Logger
The logger for logging. Will use nni logger by default (if logger is ``None``).
regular_coeff : float
The coefficient of regular loss.
regular_ratio : float
The ratio of regular loss.
warmup_epochs : int
The epochs to warmup the search network
fix_head : bool
``True`` if fixing the parameters of auxiliary heads, else unfix the parameters of auxiliary heads.
epochs : int
Number of epochs planned for training.
steps_per_epoch : int
Steps of one epoch.
loss_alpha : float
The loss coefficient.
loss_T : float
The loss coefficient.
distributed : bool
``True`` if using distributed training, else non-distributed training.
log_frequency : int
Step count per logging.
grad_clip : float
Gradient clipping for weights.
interactive_type : string
``kl`` or ``smoothl1``.
output_path : string
Log storage path.
w_lr : float
Learning rate of the search network parameters.
w_momentum : float
Momentum of the search and the evaluation network.
w_weight_decay : float
The weight decay of the search and the evaluation network parameters.
alpha_lr : float
Learning rate of the architecture parameters.
alpha_weight_decay : float
The weight decay of the architecture parameters.
nasnet_lr : float
Learning rate of the evaluation network parameters.
local_rank : int
Rank of the current process in distributed training.
share_module : bool
``True`` if sharing the stem and auxiliary heads, else not sharing these modules.
"""
if logger is None:
logger = logging.getLogger(__name__)
train_loader, valid_loader = loaders
...
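A hypothetical single-GPU wiring of this constructor (the model, loader, and sampler objects are placeholders, not part of the diff):

criterion = nn.CrossEntropyLoss()
trainer = CdartsTrainer(model_small, model_large, criterion,
                        loaders=[train_loader, valid_loader],
                        samplers=[train_sampler, valid_sampler],
                        epochs=32, distributed=False, output_path='./outputs')
trainer.train()  # assuming the usual NNI trainer interface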
@@ -22,12 +22,21 @@ INPUT_CHOICE = "input_choice"
def get_and_apply_next_architecture(model):
"""
Wrapper of :class:`~nni.nas.pytorch.classic_nas.mutator.ClassicMutator` to make it more meaningful,
similar to ``get_next_parameter`` for HPO.
It will generate the search space based on ``model``.
If env ``NNI_GEN_SEARCH_SPACE`` exists, this is in dry-run mode for
generating the search space for the experiment.
If not, there are still two modes: one is NNI experiment mode, where users
use ``nnictl`` to start an experiment; the other is standalone mode,
where users directly run the trial command. Standalone mode chooses the first
one(s) for each LayerChoice and InputChoice.
Parameters
----------
model : nn.Module
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
"""
ClassicMutator(model)
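Typical trial-code usage, with the model class being a placeholder:

model = NetWithChoices()                 # hypothetical model containing LayerChoice/InputChoice
get_and_apply_next_architecture(model)   # fixes one architecture for this trial
# ... then train `model` as an ordinary PyTorch model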
@@ -36,23 +45,15 @@ class ClassicMutator(Mutator):
""" """
This mutator is to apply the architecture chosen from tuner. This mutator is to apply the architecture chosen from tuner.
It implements the forward function of LayerChoice and InputChoice, It implements the forward function of LayerChoice and InputChoice,
to only activate the chosen ones to only activate the chosen ones.
Parameters
----------
model : nn.Module
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
""" """
def __init__(self, model): def __init__(self, model):
"""
Generate search space based on ```model```.
If env ```NNI_GEN_SEARCH_SPACE``` exists, this is in dry run mode for
generating search space for the experiment.
If not, there are still two mode, one is nni experiment mode where users
use ```nnictl``` to start an experiment. The other is standalone mode
where users directly run the trial command, this mode chooses the first
one(s) for each LayerChoice and InputChoice.
Parameters
----------
model : PyTorch model
user's model with search space (e.g., LayerChoice, InputChoice) embedded in it
"""
super(ClassicMutator, self).__init__(model)
self._chosen_arch = {}
self._search_space = self._generate_search_space()
@@ -67,6 +68,13 @@ class ClassicMutator(Mutator):
else:
# get chosen arch from tuner
self._chosen_arch = nni.get_next_parameter()
if self._chosen_arch is None:
if trial_env_vars.NNI_PLATFORM == "unittest":
# happens if NNI_PLATFORM is intentionally set, e.g., in UT
logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
self._chosen_arch = self._standalone_generate_chosen()
else:
raise RuntimeError("Chosen architecture is None. This may be a platform error.")
self.reset()
def _sample_layer_choice(self, mutable, idx, value, search_space_item):
@@ -114,9 +122,15 @@ class ClassicMutator(Mutator):
return torch.tensor(multihot_list, dtype=torch.bool) # pylint: disable=not-callable
def sample_search(self):
"""
See :meth:`sample_final`.
"""
return self.sample_final()
def sample_final(self):
"""
Convert the chosen architecture and apply it to the model.
"""
assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
"Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
self._chosen_arch.keys())
@@ -162,6 +176,8 @@ class ClassicMutator(Mutator):
elif val["_type"] == INPUT_CHOICE:
choices = val["_value"]["candidates"]
n_chosen = val["_value"]["n_chosen"]
if n_chosen is None:
n_chosen = len(choices)
chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
else:
raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
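For reference, a chosen architecture generated in standalone mode follows the ``_value``/``_idx`` convention above; the mutable keys and candidate names here are hypothetical:

chosen_arch = {
    "first_conv": {"_value": "conv3x3", "_idx": 0},                    # LayerChoice: first candidate
    "skip_inputs": {"_value": ["cell_0", "cell_1"], "_idx": [0, 1]},   # InputChoice: first n_chosen candidates
}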
...
@@ -63,18 +63,23 @@ class DartsMutator(Mutator):
edges_max[mutable.key] = max_val
result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
for mutable in self.mutables:
if isinstance(mutable, InputChoice):
    if mutable.n_chosen is not None:
        weights = []
        for src_key in mutable.choose_from:
            if src_key not in edges_max:
                _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
            weights.append(edges_max.get(src_key, 0.))
        weights = torch.tensor(weights)  # pylint: disable=not-callable
        _, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
        selected_multihot = []
        for i, src_key in enumerate(mutable.choose_from):
            if i not in topk_edge_indices and src_key in result:
                # If an edge is never selected, there is no need to calculate any op on this edge.
                # This is to eliminate redundant calculation.
                result[src_key] = torch.zeros_like(result[src_key])
            selected_multihot.append(i in topk_edge_indices)
        result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
    else:
        result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
return result
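The edge-selection logic above reduces to a top-k over edge weights followed by a multi-hot mask; a standalone illustration with made-up weights:

import torch

weights = torch.tensor([0.9, 0.1, 0.7, 0.3])   # strength of each candidate input edge
_, topk_edge_indices = torch.topk(weights, 2)  # keep the two strongest edges
multihot = torch.tensor([i in topk_edge_indices for i in range(len(weights))], dtype=torch.bool)
# multihot -> tensor([True, False, True, False])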
@@ -15,46 +15,46 @@ logger = logging.getLogger(__name__)
class DartsTrainer(Trainer):
"""
DARTS trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground-truth labels and returns a loss tensor.
metrics : callable
Receives logits and ground-truth labels and returns a dict of metrics.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset_train : Dataset
Dataset for training. Will be split for training weights and architecture weights.
dataset_valid : Dataset
Dataset for testing.
mutator : DartsMutator
Use in case of customizing your own DartsMutator. By default will instantiate a DartsMutator.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Number of steps between logging outputs.
callbacks : list of Callback
List of callbacks to trigger at events.
arc_learning_rate : float
Learning rate of architecture parameters.
unrolled : bool
``True`` if using second-order optimization, else first-order optimization.
"""
def __init__(self, model, loss, metrics,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
"""
Initialize a DartsTrainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset_train : Dataset
Dataset for training. Will be split for training weights and architecture weights.
dataset_valid : Dataset
Dataset for testing.
mutator : DartsMutator
Use in case of customizing your own DartsMutator. By default will instantiate a DartsMutator.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
callbacks : list of Callback
list of callbacks to trigger at events.
arc_learning_rate : float
Learning rate of architecture parameters.
unrolled : bool
``True`` if using second order optimization, else first order optimization.
"""
super().__init__(model, mutator if mutator is not None else DartsMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
...
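A hypothetical invocation of this constructor (datasets, the metrics function, and the model are placeholders):

optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9, weight_decay=3e-4)
trainer = DartsTrainer(model, loss=nn.CrossEntropyLoss(), metrics=accuracy_fn,
                       optimizer=optimizer, num_epochs=50,
                       dataset_train=train_set, dataset_valid=valid_set,
                       batch_size=64, unrolled=False)
trainer.train()  # assuming the usual Trainer train/validate loop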
@@ -28,38 +28,38 @@ class StackedLSTMCell(nn.Module):
class EnasMutator(Mutator):
"""
A mutator that mutates the graph with RL.
Parameters
----------
model : nn.Module
PyTorch model.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. ``tanh`` will not be applied if this value is ``None``.
cell_exit_extra_step : bool
If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
and mark it as the hidden state of this MutableScope. This is to align with the original implementation of the paper.
skip_target : float
Target probability that skipconnect will appear.
temperature : float
Temperature constant that divides the logits.
branch_bias : float
Manual bias applied to make some operations more likely to be chosen.
Currently this is implemented with a hardcoded match rule that aligns with original repo.
If a mutable has ``reduce`` in its key, all its op choices
that contain ``conv`` in their typename will receive a bias of ``+self.branch_bias`` initially, while others
receive a bias of ``-self.branch_bias``.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
"""
Initialize a EnasMutator.
Parameters
----------
model : nn.Module
PyTorch model.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
cell_exit_extra_step : bool
If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
skip_target : float
Target probability that skipconnect will appear.
temperature : float
Temperature constant that divides the logits.
branch_bias : float
Manual bias applied to make some operations more likely to be chosen.
Currently this is implemented with a hardcoded match rule that aligns with original repo.
If a mutable has a ``reduce`` in its key, all its op choices
that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others
receive a bias of ``-self.branch_bias``.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
super().__init__(model)
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
...
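A sketch of the logit shaping that ``temperature`` and ``tanh_constant`` describe in the docstring above (the helper is illustrative, not the mutator's actual method):

import torch

def shape_logits(logits, temperature=None, tanh_constant=1.5):
    # Optional temperature division, then optional tanh squashing into (-tanh_constant, tanh_constant).
    if temperature is not None:
        logits = logits / temperature
    if tanh_constant is not None:
        logits = tanh_constant * torch.tanh(logits)
    return logits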
@@ -16,64 +16,64 @@ logger = logging.getLogger(__name__)
class EnasTrainer(Trainer):
"""
ENAS trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground-truth labels and returns a loss tensor.
metrics : callable
Receives logits and ground-truth labels and returns a dict of metrics.
reward_function : callable
Receives logits and ground-truth labels and returns a tensor, which will be fed to the RL controller as the reward.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset_train : Dataset
Dataset for training. Will be split for training weights and architecture weights.
dataset_valid : Dataset
Dataset for testing.
mutator : EnasMutator
Use when customizing your own mutator or a mutator with customized parameters.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Number of steps between logging outputs.
callbacks : list of Callback
List of callbacks to trigger at events.
entropy_weight : float
Weight of sample entropy loss.
skip_weight : float
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
child_steps : int
How many mini-batches for model training per epoch.
mutator_lr : float
Learning rate for RL controller.
mutator_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
mutator_steps : int
Number of mini-batches for each epoch of RL controller learning.
aux_weight : float
Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
test_arc_per_epoch : int
How many architectures are chosen for direct test after each epoch.
"""
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
test_arc_per_epoch=1):
"""
Initialize an EnasTrainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
reward_function : callable
Receives logits and ground truth label, return a tensor, which will be fed to RL controller as reward.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset_train : Dataset
Dataset for training. Will be split for training weights and architecture weights.
dataset_valid : Dataset
Dataset for testing.
mutator : EnasMutator
Use when customizing your own mutator or a mutator with customized parameters.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
callbacks : list of Callback
list of callbacks to trigger at events.
entropy_weight : float
Weight of sample entropy loss.
skip_weight : float
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
child_steps : int
How many mini-batches for model training per epoch.
mutator_lr : float
Learning rate for RL controller.
mutator_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
mutator_steps : int
Number of mini-batches for each epoch of RL controller learning.
aux_weight : float
Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
test_arc_per_epoch : int
How many architectures are chosen for direct test after each epoch.
"""
super().__init__(model, mutator if mutator is not None else EnasMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
...
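The ``baseline_decay`` docstring above describes an exponential moving average of the reward; as a one-line sketch (the function name is illustrative):

def update_baseline(baseline, reward, baseline_decay=0.999):
    # New baseline = baseline_decay * baseline_old + reward * (1 - baseline_decay)
    return baseline_decay * baseline + reward * (1 - baseline_decay)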
@@ -10,20 +10,20 @@ from nni.nas.pytorch.mutator import Mutator
class FixedArchitecture(Mutator):
"""
Fixed architecture mutator that always selects a certain graph.
Parameters
----------
model : nn.Module
A mutable network.
fixed_arc : str or dict
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
"""
def __init__(self, model, fixed_arc, strict=True):
"""
Initialize a fixed architecture mutator.
Parameters
----------
model : nn.Module
A mutable network.
fixed_arc : str or dict
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
strict : bool
Force everything that appears in `fixed_arc` to be used at least once.
"""
super().__init__(model)
self._fixed_arc = fixed_arc
@@ -35,9 +35,15 @@ class FixedArchitecture(Mutator):
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
def sample_search(self):
"""
Always returns the fixed architecture.
"""
return self._fixed_arc
def sample_final(self):
"""
Always returns the fixed architecture.
"""
return self._fixed_arc
@@ -52,24 +58,25 @@ def _encode_tensor(data):
return data
def apply_fixed_architecture(model, fixed_arc):
"""
Load architecture from ``fixed_arc`` and apply it to the model.
Parameters
----------
model : torch.nn.Module
Model with mutables.
fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
Returns
-------
FixedArchitecture
Mutator that is responsible for fixing the graph.
""" """
if isinstance(fixed_arc_path, str): if isinstance(fixed_arc, str):
with open(fixed_arc_path, "r") as f: with open(fixed_arc) as f:
fixed_arc = json.load(f) fixed_arc = json.load(f)
fixed_arc = _encode_tensor(fixed_arc) fixed_arc = _encode_tensor(fixed_arc)
architecture = FixedArchitecture(model, fixed_arc) architecture = FixedArchitecture(model, fixed_arc)
......
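Hypothetical retraining usage of the updated signature, which now accepts either a JSON path or a dict (the model class and path are placeholders):

model = MyNetwork()                                        # model with mutables
apply_fixed_architecture(model, "checkpoints/arch.json")   # or pass the exported dict directly
# `model` now behaves as a fixed architecture and can be trained normally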