Unverified commit 75028bd7, authored by SparkSnail, committed by GitHub

Merge pull request #235 from microsoft/master

merge master
parents 1d74ae5e 2e42d1d8
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import { TrialConfig } from "training_service/common/trialConfig";

export class DLTSTrialConfig extends TrialConfig {
    public constructor(
        command: string,
        codeDir: string,
        gpuNum: number,
        public readonly image: string
    ) {
        super(command, codeDir, gpuNum);
    }
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import {
    TrialJobDetail,
    TrialJobStatus,
    TrialJobApplicationForm
} from "../../common/trainingService";

export class DLTSTrialJobDetail implements TrialJobDetail {
    public startTime?: number;
    public endTime?: number;
    public tags?: string[];
    public url?: string;
    public isEarlyStopped?: boolean;

    // DLTS staff
    public dltsJobId?: string;
    public dltsPaused: boolean = false;

    public constructor(
        public id: string,
        public status: TrialJobStatus,
        public submitTime: number,
        public workingDirectory: string,
        public form: TrialJobApplicationForm,
        // DLTS staff
        public dltsJobName: string,
    ) {}
}
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-from .compressor import LayerInfo, Compressor, Pruner, Quantizer
+from .compressor import Compressor, Pruner, Quantizer
 from .pruners import *
 from .weight_rank_filter_pruners import *
 from .activation_rank_filter_pruners import *
...
@@ -16,7 +16,7 @@ class ActivationRankFilterPruner(Pruner):
     to achieve a preset level of network sparsity.
     """

-    def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -25,17 +25,23 @@ class ActivationRankFilterPruner(Pruner):
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         activation : str
             Activation function
         statistics_batch_num : int
             Num of batches for activation statistics
         """
-        super().__init__(model, config_list)
-        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
+        super().__init__(model, config_list, optimizer)
+        self.set_wrappers_attribute("if_calculated", False)
+        self.set_wrappers_attribute("collected_activation", [])
         self.statistics_batch_num = statistics_batch_num
-        self.collected_activation = {}
-        self.hooks = {}
+
+        def collector(module_, input_, output):
+            if len(module_.collected_activation) < self.statistics_batch_num:
+                module_.collected_activation.append(self.activation(output.detach().cpu()))
+        self.add_activation_collector(collector)
         assert activation in ['relu', 'relu6']
         if activation == 'relu':
             self.activation = torch.nn.functional.relu
@@ -44,33 +50,10 @@ class ActivationRankFilterPruner(Pruner):
         else:
             self.activation = None

-    def compress(self):
-        """
-        Compress the model, register a hook for collecting activations.
-        """
-        if self.modules_wrapper is not None:
-            # already compressed
-            return self.bound_model
-        else:
-            self.modules_wrapper = []
-        modules_to_compress = self.detect_modules_to_compress()
-        for layer, config in modules_to_compress:
-            wrapper = self._wrap_modules(layer, config)
-            self.modules_wrapper.append(wrapper)
-            self.collected_activation[layer.name] = []
-
-            def _hook(module_, input_, output, name=layer.name):
-                if len(self.collected_activation[name]) < self.statistics_batch_num:
-                    self.collected_activation[name].append(self.activation(output.detach().cpu()))
-
-            wrapper.module.register_forward_hook(_hook)
-        self._wrap_model()
-        return self.bound_model
-
     def get_mask(self, base_mask, activations, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config, **kwargs):
+    def calc_mask(self, wrapper, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion which is calculated from the activation are masked.
@@ -88,29 +71,30 @@ class ActivationRankFilterPruner(Pruner):
             dictionary for storing masks
         """

-        weight = layer.module.weight.data
-        op_type = layer.type
+        weight = wrapper.module.weight.data
+        op_type = wrapper.type
+        config = wrapper.config
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv2d'], "only support Conv2d"
         assert op_type in config.get('op_types')
-        if_calculated = kwargs["if_calculated"]
-        if if_calculated:
+        if wrapper.if_calculated:
             return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
-        if hasattr(layer.module, 'bias') and layer.module.bias is not None:
-            mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
+        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
+            mask_bias = torch.ones(wrapper.module.bias.size()).type_as(wrapper.module.bias).detach()
         else:
             mask_bias = None
-        mask = {'weight': mask_weight, 'bias': mask_bias}
+        mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
         try:
             filters = weight.size(0)
             num_prune = int(filters * config.get('sparsity'))
-            if filters < 2 or num_prune < 1 or len(self.collected_activation[layer.name]) < self.statistics_batch_num:
+            if filters < 2 or num_prune < 1 or len(wrapper.collected_activation) < self.statistics_batch_num:
                 return mask
-            mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
+            mask = self.get_mask(mask, wrapper.collected_activation, num_prune)
         finally:
-            if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
-                if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+            if len(wrapper.collected_activation) == self.statistics_batch_num:
+                wrapper.if_calculated = True
         return mask
@@ -123,7 +107,7 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
     https://arxiv.org/abs/1607.03250
     """

-    def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -132,12 +116,14 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         activation : str
             Activation function
         statistics_batch_num : int
             Num of batches for activation statistics
         """
-        super().__init__(model, config_list, activation, statistics_batch_num)
+        super().__init__(model, config_list, optimizer, activation, statistics_batch_num)

     def get_mask(self, base_mask, activations, num_prune):
         """
@@ -161,9 +147,9 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
         apoz = self._calc_apoz(activations)
         prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
         for idx in prune_indices:
-            base_mask['weight'][idx] = 0.
-            if base_mask['bias'] is not None:
-                base_mask['bias'][idx] = 0.
+            base_mask['weight_mask'][idx] = 0.
+            if base_mask['bias_mask'] is not None:
+                base_mask['bias_mask'][idx] = 0.
         return base_mask

     def _calc_apoz(self, activations):
@@ -195,7 +181,7 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
     https://arxiv.org/abs/1611.06440
     """

-    def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -204,12 +190,14 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         activation : str
             Activation function
         statistics_batch_num : int
             Num of batches for activation statistics
         """
-        super().__init__(model, config_list, activation, statistics_batch_num)
+        super().__init__(model, config_list, optimizer, activation, statistics_batch_num)

     def get_mask(self, base_mask, activations, num_prune):
         """
@@ -233,9 +221,9 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
         mean_activation = self._cal_mean_activation(activations)
         prune_indices = torch.argsort(mean_activation)[:num_prune]
         for idx in prune_indices:
-            base_mask['weight'][idx] = 0.
-            if base_mask['bias'] is not None:
-                base_mask['bias'][idx] = 0.
+            base_mask['weight_mask'][idx] = 0.
+            if base_mask['bias_mask'] is not None:
+                base_mask['bias_mask'][idx] = 0.
         return base_mask

     def _cal_mean_activation(self, activations):
...
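For orientation: after this change the activation statistics live on each module wrapper rather than in a pruner-side dict keyed by layer name, and the forward hook is installed through add_activation_collector. A self-contained sketch of that hook pattern, using an illustrative WrapperLike stand-in that is not part of this diff:

    import torch
    import torch.nn as nn

    class WrapperLike(nn.Module):
        """Illustrative stand-in for PrunerModuleWrapper."""
        def __init__(self, module):
            super().__init__()
            self.module = module
            self.collected_activation = []  # what set_wrappers_attribute("collected_activation", []) provides

        def forward(self, x):
            return self.module(x)

    wrapper = WrapperLike(nn.Conv2d(3, 8, 3))
    statistics_batch_num = 1

    def collector(module_, input_, output):
        # keep at most statistics_batch_num batches of post-activation statistics
        if len(module_.collected_activation) < statistics_batch_num:
            module_.collected_activation.append(torch.nn.functional.relu(output.detach().cpu()))

    handle = wrapper.register_forward_hook(collector)  # add_activation_collector does this per wrapper
    wrapper(torch.randn(4, 3, 16, 16))
    assert len(wrapper.collected_activation) == 1
    handle.remove()                                    # remove_activation_collector does this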
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+import types
 import logging
 import torch
 from . import default_layers
@@ -20,12 +21,13 @@ def _setattr(model, name, module):
         model = getattr(model, name)
     setattr(model, name_list[-1], module)

 class Compressor:
     """
     Abstract base PyTorch compressor
     """

-    def __init__(self, model, config_list):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Record necessary info in class members
@@ -35,15 +37,27 @@ class Compressor:
             the model user wants to compress
         config_list : list
             the configurations that users specify for compression
+        optimizer: pytorch optimizer
+            optimizer used to train the model
         """
         self.bound_model = model
         self.config_list = config_list
+        self.optimizer = optimizer
         self.modules_to_compress = None
-        self.modules_wrapper = None
-        self.buffers = {}
+        self.modules_wrapper = []
         self.is_wrapped = False

-    def detect_modules_to_compress(self):
+        self._fwd_hook_handles = {}
+        self._fwd_hook_id = 0
+
+        for layer, config in self._detect_modules_to_compress():
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
+        self._wrap_model()
+
+    def _detect_modules_to_compress(self):
         """
         detect all modules should be compressed, and save the result in `self.modules_to_compress`.
         The model will be instrumented and user should never edit it after calling this method.
@@ -87,26 +101,26 @@ class Compressor:
         torch.nn.Module
             model with specified modules compressed.
         """
-        if self.modules_wrapper is not None:
-            # already compressed
-            return self.bound_model
-        else:
-            self.modules_wrapper = []
-
-        modules_to_compress = self.detect_modules_to_compress()
-        for layer, config in modules_to_compress:
-            wrapper = self._wrap_modules(layer, config)
-            self.modules_wrapper.append(wrapper)
-
-        self._wrap_model()
         return self.bound_model

-    def register_buffer(self, name, value):
+    def set_wrappers_attribute(self, name, value):
         """
-        To register buffers used in wrapped module's forward method.
+        To register attributes used in wrapped module's forward method.
+        If the type of the value is Torch.tensor, then this value is registered as a buffer in wrapper,
+        which will be saved by model.state_dict. Otherwise, this value is just a regular variable in wrapper.
+
+        Parameters
+        ----------
+        name : str
+            name of the variable
+        value: any
+            value of the variable
         """
-        self.buffers[name] = value
+        for wrapper in self.get_modules_wrapper():
+            if isinstance(value, torch.Tensor):
+                wrapper.register_buffer(name, value.clone())
+            else:
+                setattr(wrapper, name, value)

     def get_modules_to_compress(self):
         """
@@ -180,11 +194,7 @@ class Compressor:
         epoch : num
             the current epoch number
         """
-        pass
-
-    def step(self):
-        """
-        If user want to update model every step, user can override this method
-        """

     def _wrap_modules(self, layer, config):
         """
@@ -200,6 +210,34 @@ class Compressor:
         raise NotImplementedError()

+    def add_activation_collector(self, collector):
+        self._fwd_hook_id += 1
+        self._fwd_hook_handles[self._fwd_hook_id] = []
+        for wrapper in self.get_modules_wrapper():
+            handle = wrapper.register_forward_hook(collector)
+            self._fwd_hook_handles[self._fwd_hook_id].append(handle)
+        return self._fwd_hook_id
+
+    def remove_activation_collector(self, fwd_hook_id):
+        if fwd_hook_id not in self._fwd_hook_handles:
+            raise ValueError("%s is not a valid collector id" % str(fwd_hook_id))
+        for handle in self._fwd_hook_handles[fwd_hook_id]:
+            handle.remove()
+        del self._fwd_hook_handles[fwd_hook_id]
+
+    def patch_optimizer(self, *tasks):
+        def patch_step(old_step):
+            def new_step(_, *args, **kwargs):
+                # call origin optimizer step method
+                output = old_step(*args, **kwargs)
+                # calculate mask
+                for task in tasks:
+                    task()
+                return output
+            return new_step
+        if self.optimizer is not None:
+            self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)

 class PrunerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, pruner):
         """
@@ -226,7 +264,6 @@ class PrunerModuleWrapper(torch.nn.Module):
         # config and pruner
         self.config = config
         self.pruner = pruner
-        self.registered_buffers = []

         # register buffer for mask
         self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
@@ -234,29 +271,11 @@ class PrunerModuleWrapper(torch.nn.Module):
             self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
         else:
             self.register_buffer("bias_mask", None)
-        self.registered_buffers.append('weight_mask')
-        self.registered_buffers.append('bias_mask')
-
-        # register user specified buffer
-        for name in self.pruner.buffers:
-            self.register_buffer(name, self.pruner.buffers[name].clone())
-            self.registered_buffers.append(name)
-
-    def get_registered_buffers(self):
-        buffers = {}
-        for name in self.registered_buffers:
-            buffers[name] = getattr(self, name)
-        return buffers

     def forward(self, *inputs):
-        mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.get_registered_buffers())
-        if mask is not None:
-            self.weight_mask.copy_(mask['weight'])
-        # apply mask to weight
+        # apply mask to weight, bias
         self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
-        # apply mask to bias
         if hasattr(self.module, 'bias') and self.module.bias is not None:
-            if mask is not None and 'bias' in mask:
-                self.bias_mask.copy_(mask['bias'])
             self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
         return self.module(*inputs)
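With calc_mask moved out of the forward pass, the wrapper's forward now only multiplies the registered mask buffers into the parameters before delegating to the wrapped module. In isolation, that step looks like this (the conv layer and the zeroed filter are illustrative):

    import torch

    conv = torch.nn.Conv2d(3, 8, 3)
    weight_mask = torch.ones_like(conv.weight)
    weight_mask[0] = 0.  # pretend update_mask() zeroed filter 0 on the wrapper

    # the masking step PrunerModuleWrapper.forward performs before calling the module:
    conv.weight.data = conv.weight.data.mul_(weight_mask)
    out = conv(torch.randn(1, 3, 16, 16))
    assert torch.allclose(out[:, 0], conv.bias[0])  # filter 0 now contributes only its bias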
@@ -272,10 +291,24 @@ class Pruner(Compressor):
     """

-    def __init__(self, model, config_list):
-        super().__init__(model, config_list)
+    def __init__(self, model, config_list, optimizer=None):
+        super().__init__(model, config_list, optimizer)
+        if optimizer is not None:
+            self.patch_optimizer(self.update_mask)
+
+    def compress(self):
+        self.update_mask()
+        return self.bound_model
+
+    def update_mask(self):
+        for wrapper in self.get_modules_wrapper():
+            masks = self.calc_mask(wrapper)
+            if masks is not None:
+                for k in masks:
+                    assert hasattr(wrapper, k), "there is no attribute '%s' in wrapper" % k
+                    setattr(wrapper, k, masks[k])

-    def calc_mask(self, layer, config, **kwargs):
+    def calc_mask(self, wrapper, **kwargs):
         """
         Pruners should overload this method to provide mask for weight tensors.
         The mask must have the same shape and type comparing to the weight.
@@ -284,10 +317,8 @@ class Pruner(Compressor):
         Parameters
         ----------
-        layer : LayerInfo
-            calculate mask for `layer`'s weight
-        config : dict
-            the configuration for generating the mask
+        wrapper : Module
+            calculate mask for `wrapper.module`'s weight
         """
         raise NotImplementedError("Pruners must overload calc_mask()")
@@ -327,8 +358,6 @@ class Pruner(Compressor):
             device of the model, used to place the dummy input tensor for exporting onnx file.
             the tensor is placed on cpu if ```device``` is None
         """
-        # if self.detect_modules_to_compress() and not self.mask_dict:
-        #     _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
         assert model_path is not None, 'model_path must be specified'
         mask_dict = {}
         self._unwrap_model() # used for generating correct state_dict name without wrapper state
@@ -404,7 +433,6 @@ class QuantizerModuleWrapper(torch.nn.Module):
         # config and pruner
         self.config = config
         self.quantizer = quantizer
-        self.registered_buffers = []

         # register buffer and parameter
         # old_weight is used to store origin weight and weight is used to store quantized weight
@@ -418,35 +446,18 @@ class QuantizerModuleWrapper(torch.nn.Module):
             delattr(self.module, 'weight')
             self.module.register_buffer('weight', self.module.old_weight)

-        # register user specified buffer
-        for name in self.quantizer.buffers:
-            self.register_buffer(name, self.quantizer.buffers[name].clone())
-            self.registered_buffers.append(name)
-
-    def get_registered_buffers(self):
-        buffers = {}
-        for name in self.registered_buffers:
-            buffers[name] = getattr(self, name)
-        return buffers
-
     def forward(self, *inputs):
         if 'input' in self.config['quant_types']:
             inputs = self.quantizer.quant_grad.apply(
                 inputs,
                 QuantType.QUANT_INPUT,
-                self.quantizer.quantize_input,
-                self.config,
-                LayerInfo(self.name, self.module),
-                **self.get_registered_buffers())
+                self)

         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
             new_weight = self.quantizer.quant_grad.apply(
                 self.module.old_weight,
                 QuantType.QUANT_WEIGHT,
-                self.quantizer.quantize_weight,
-                self.config,
-                LayerInfo(self.name, self.module),
-                **self.get_registered_buffers())
+                self)
             self.module.weight = new_weight
             result = self.module(*inputs)
         else:
@@ -456,10 +467,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
             result = self.quantizer.quant_grad.apply(
                 result,
                 QuantType.QUANT_OUTPUT,
-                self.quantizer.quantize_output,
-                self.config,
-                LayerInfo(self.name, self.module),
-                **self.get_registered_buffers())
+                self)
         return result

 class Quantizer(Compressor):
@@ -467,11 +475,18 @@ class Quantizer(Compressor):
     Base quantizer for pytorch quantizer
     """

-    def __init__(self, model, config_list):
-        super().__init__(model, config_list)
+    def __init__(self, model, config_list, optimizer=None):
+        super().__init__(model, config_list, optimizer)
         self.quant_grad = QuantGrad
+        if self.optimizer is not None:
+            self.patch_optimizer(self.step_with_optimizer)
+            for wrapper in self.get_modules_wrapper():
+                if 'weight' in wrapper.config['quant_types']:
+                    # old_weight is registered to keep track of weight before quantization
+                    # and it is trainable, therefore, it should be added to optimizer.
+                    self.optimizer.add_param_group({"params": wrapper.module.old_weight})

-    def quantize_weight(self, weight, config, op, op_type, op_name):
+    def quantize_weight(self, weight, wrapper, **kwargs):
         """
         quantize should overload this method to quantize weight.
         This method is effectively hooked to :meth:`forward` of the model.
@@ -479,12 +494,12 @@ class Quantizer(Compressor):
         ----------
         weight : Tensor
             weight that needs to be quantized
-        config : dict
-            the configuration for weight quantization
+        wrapper : QuantizerModuleWrapper
+            the wrapper for origin module
         """
         raise NotImplementedError('Quantizer must overload quantize_weight()')

-    def quantize_output(self, output, config, op, op_type, op_name):
+    def quantize_output(self, output, wrapper, **kwargs):
         """
         quantize should overload this method to quantize output.
         This method is effectively hooked to :meth:`forward` of the model.
@@ -492,12 +507,12 @@ class Quantizer(Compressor):
         ----------
         output : Tensor
             output that needs to be quantized
-        config : dict
-            the configuration for output quantization
+        wrapper : QuantizerModuleWrapper
+            the wrapper for origin module
         """
         raise NotImplementedError('Quantizer must overload quantize_output()')

-    def quantize_input(self, *inputs, config, op, op_type, op_name):
+    def quantize_input(self, *inputs, wrapper, **kwargs):
         """
         quantize should overload this method to quantize input.
         This method is effectively hooked to :meth:`forward` of the model.
@@ -505,8 +520,8 @@ class Quantizer(Compressor):
         ----------
         inputs : Tensor
             inputs that needs to be quantized
-        config : dict
-            the configuration for inputs quantization
+        wrapper : QuantizerModuleWrapper
+            the wrapper for origin module
         """
         raise NotImplementedError('Quantizer must overload quantize_input()')
@@ -532,6 +547,9 @@ class Quantizer(Compressor):
         return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)

+    def step_with_optimizer(self):
+        pass
+
 class QuantType:
     """
     Enum class for quantization type.
@@ -540,6 +558,7 @@ class QuantType:
     QUANT_WEIGHT = 1
     QUANT_OUTPUT = 2
+
 class QuantGrad(torch.autograd.Function):
     """
     Base class for overriding backward function of quantization operation.
@@ -566,15 +585,22 @@ class QuantGrad(torch.autograd.Function):
         return grad_output

     @staticmethod
-    def forward(ctx, tensor, quant_type, quant_func, config, layer, **kwargs):
+    def forward(ctx, tensor, quant_type, wrapper, **kwargs):
         ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
-        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name, **kwargs)
+        if quant_type == QuantType.QUANT_INPUT:
+            return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
+        elif quant_type == QuantType.QUANT_WEIGHT:
+            return wrapper.quantizer.quantize_weight(tensor, wrapper, **kwargs)
+        elif quant_type == QuantType.QUANT_OUTPUT:
+            return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
+        else:
+            raise ValueError("unrecognized QuantType.")

     @classmethod
     def backward(cls, ctx, grad_output):
         tensor, quant_type = ctx.saved_variables
         output = cls.quant_backward(tensor, grad_output, quant_type)
-        return output, None, None, None, None, None
+        return output, None, None, None

 def _check_weight(module):
     try:
...
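Quantizer callbacks now receive the wrapper instead of the old (config, op, op_type, op_name) tuple, and QuantGrad.forward dispatches to them by quant_type. A toy quantizer written against the new signature; the class name and import path are illustrative assumptions:

    import torch
    from nni.compression.torch import Quantizer  # import path assumed

    class RoundingQuantizer(Quantizer):
        """Toy quantizer against the new wrapper-based callback signature."""
        def quantize_weight(self, weight, wrapper, **kwargs):
            # per-layer settings are now read off the wrapper instead of a config argument
            assert 'weight' in wrapper.config['quant_types']
            return torch.round(weight * 10) / 10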
@@ -16,7 +16,7 @@ class LevelPruner(Pruner):
     Prune to an exact pruning level specification
     """

-    def __init__(self, model, config_list):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -24,38 +24,39 @@ class LevelPruner(Pruner):
             Model to be pruned
         config_list : list
             List on pruning configs
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
-        super().__init__(model, config_list)
-        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
+        super().__init__(model, config_list, optimizer)
+        self.set_wrappers_attribute("if_calculated", False)

-    def calc_mask(self, layer, config, **kwargs):
+    def calc_mask(self, wrapper, **kwargs):
         """
         Calculate the mask of given layer

         Parameters
         ----------
-        layer : LayerInfo
-            the layer to instrument the compression operation
-        config : dict
-            layer's pruning config
+        wrapper : Module
+            the module to instrument the compression operation

         Returns
         -------
         dict
             dictionary for storing masks
         """
-        weight = layer.module.weight.data
-        if_calculated = kwargs["if_calculated"]
-        if not if_calculated:
+        config = wrapper.config
+        weight = wrapper.module.weight.data
+        if not wrapper.if_calculated:
             w_abs = weight.abs()
             k = int(weight.numel() * config['sparsity'])
             if k == 0:
                 return torch.ones(weight.shape).type_as(weight)
             threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
             mask_weight = torch.gt(w_abs, threshold).type_as(weight)
-            mask = {'weight': mask_weight}
-            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+            mask = {'weight_mask': mask_weight}
+            wrapper.if_calculated = True
             return mask
         else:
             return None
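End-to-end, LevelPruner stays one-shot under the new flow: wrapping happens in __init__, and compress() triggers a single update_mask pass. A usage sketch; the model, config values, and import path are illustrative:

    import torch
    from nni.compression.torch import LevelPruner  # import path assumed

    model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.Linear(8, 2))
    config_list = [{'sparsity': 0.5, 'op_types': ['default']}]

    pruner = LevelPruner(model, config_list)  # wrapping now happens here, in __init__
    model = pruner.compress()                 # runs update_mask() -> calc_mask(wrapper) once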
@@ -71,7 +72,7 @@ class AGP_Pruner(Pruner):
     https://arxiv.org/pdf/1710.01878.pdf
     """

-    def __init__(self, model, config_list):
+    def __init__(self, model, config_list, optimizer):
         """
         Parameters
         ----------
@@ -79,50 +80,51 @@ class AGP_Pruner(Pruner):
             Model to be pruned
         config_list : list
             List on pruning configs
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
-        super().__init__(model, config_list)
+        super().__init__(model, config_list, optimizer)
+        assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
         self.now_epoch = 0
-        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
+        self.set_wrappers_attribute("if_calculated", False)

-    def calc_mask(self, layer, config, **kwargs):
+    def calc_mask(self, wrapper, **kwargs):
         """
         Calculate the mask of given layer.
         Scale factors with the smallest absolute value in the BN layer are masked.

         Parameters
         ----------
-        layer : LayerInfo
+        wrapper : Module
             the layer to instrument the compression operation
-        config : dict
-            layer's pruning config
-        kwargs: dict
-            buffers registered in __init__ function

         Returns
         -------
         dict
             dictionary for storing masks
         """
-        weight = layer.module.weight.data
+        config = wrapper.config
+        weight = wrapper.module.weight.data
         start_epoch = config.get('start_epoch', 0)
         freq = config.get('frequency', 1)
-        if_calculated = kwargs["if_calculated"]
-        if if_calculated:
+        if wrapper.if_calculated:
            return None
         if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
             return None

-        mask = {'weight': kwargs['weight_mask'] if 'weight_mask' in kwargs else torch.ones(weight.shape).type_as(weight)}
+        mask = {'weight_mask': wrapper.weight_mask}
         target_sparsity = self.compute_target_sparsity(config)
         k = int(weight.numel() * target_sparsity)
         if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
             return mask
         # if we want to generate new mask, we should update weigth first
-        w_abs = weight.abs() * mask['weight']
+        w_abs = weight.abs() * mask['weight_mask']
         threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-        new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
-        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+        new_mask = {'weight_mask': torch.gt(w_abs, threshold).type_as(weight)}
+        wrapper.if_calculated = True

         return new_mask
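Since AGP_Pruner now requires a real optimizer and masks are refreshed by the patched optimizer.step, a training loop drives the sparsity schedule and update_epoch advances it. A sketch under those assumptions; the model, data, config keys, and import path are illustrative:

    import torch
    from nni.compression.torch import AGP_Pruner  # import path and config keys assumed

    model = torch.nn.Linear(16, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    config_list = [{'initial_sparsity': 0.0, 'final_sparsity': 0.8, 'start_epoch': 0,
                    'end_epoch': 10, 'frequency': 1, 'op_types': ['default']}]

    pruner = AGP_Pruner(model, config_list, optimizer)  # asserts a real optimizer
    model = pruner.compress()
    for epoch in range(10):
        pruner.update_epoch(epoch)
        for _ in range(100):                      # stand-in for a data loader
            optimizer.zero_grad()
            model(torch.randn(8, 16)).sum().backward()
            optimizer.step()                      # patched: also re-runs calc_mask on each wrapper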
@@ -180,62 +182,64 @@ class SlimPruner(Pruner):
     https://arxiv.org/pdf/1708.06519.pdf
     """

-    def __init__(self, model, config_list):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
+        model : torch.nn.module
+            Model to be pruned
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
-        super().__init__(model, config_list)
+        super().__init__(model, config_list, optimizer)
         weight_list = []
         if len(config_list) > 1:
             logger.warning('Slim pruner only supports 1 configuration')
         config = config_list[0]
-        for (layer, config) in self.detect_modules_to_compress():
+        for (layer, config) in self.get_modules_to_compress():
             assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
             weight_list.append(layer.module.weight.data.abs().clone())
         all_bn_weights = torch.cat(weight_list)
         k = int(all_bn_weights.shape[0] * config['sparsity'])
         self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
-        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
+        self.set_wrappers_attribute("if_calculated", False)

-    def calc_mask(self, layer, config, **kwargs):
+    def calc_mask(self, wrapper, **kwargs):
         """
         Calculate the mask of given layer.
         Scale factors with the smallest absolute value in the BN layer are masked.

         Parameters
         ----------
-        layer : LayerInfo
+        wrapper : Module
             the layer to instrument the compression operation
-        config : dict
-            layer's pruning config
-        kwargs: dict
-            buffers registered in __init__ function

         Returns
         -------
         dict
             dictionary for storing masks
         """
-        weight = layer.module.weight.data
-        op_type = layer.type
-        if_calculated = kwargs["if_calculated"]
+        config = wrapper.config
+        weight = wrapper.module.weight.data
+        op_type = wrapper.type
         assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-        if if_calculated:
+        if wrapper.if_calculated:
             return None
         base_mask = torch.ones(weight.size()).type_as(weight).detach()
-        mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
+        mask = {'weight_mask': base_mask.detach(), 'bias_mask': base_mask.clone().detach()}
         filters = weight.size(0)
         num_prune = int(filters * config.get('sparsity'))
         if filters >= 2 and num_prune >= 1:
             w_abs = weight.abs()
             mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
             mask_bias = mask_weight.clone()
-            mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
-        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+            mask = {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
+        wrapper.if_calculated = True
         return mask
 class LotteryTicketPruner(Pruner):
@@ -250,7 +254,7 @@ class LotteryTicketPruner(Pruner):
     5. Repeat step 2, 3, and 4.
     """

-    def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
+    def __init__(self, model, config_list, optimizer=None, lr_scheduler=None, reset_weights=True):
         """
         Parameters
         ----------
@@ -267,7 +271,7 @@ class LotteryTicketPruner(Pruner):
         reset_weights : bool
             Whether reset weights and optimizer at the beginning of each round.
         """
-        super().__init__(model, config_list)
+        super().__init__(model, config_list, optimizer)
         self.curr_prune_iteration = None
         self.prune_iterations = self._validate_config(config_list)
@@ -307,20 +311,16 @@ class LotteryTicketPruner(Pruner):
         k = int(w_abs.numel() * curr_sparsity)
         threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
         mask = torch.gt(w_abs, threshold).type_as(weight)
-        return {'weight': mask}
+        return {'weight_mask': mask}

-    def calc_mask(self, layer, config, **kwargs):
+    def calc_mask(self, wrapper, **kwargs):
         """
         Generate mask for the given ``weight``.

         Parameters
         ----------
-        layer : LayerInfo
+        wrapper : Module
             The layer to be pruned
-        config : dict
-            Pruning configurations for this weight
-        kwargs : dict
-            Auxiliary information

         Returns
         -------
@@ -355,7 +355,7 @@ class LotteryTicketPruner(Pruner):
         assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'

         modules_wrapper = self.get_modules_wrapper()
-        modules_to_compress = self.detect_modules_to_compress()
+        modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             module_wrapper = None
             for wrapper in modules_wrapper:
@@ -367,7 +367,7 @@ class LotteryTicketPruner(Pruner):
             sparsity = config.get('sparsity')
             mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
             # TODO: directly use weight_mask is not good
-            module_wrapper.weight_mask.copy_(mask['weight'])
+            module_wrapper.weight_mask = mask['weight_mask']
             # there is no mask for bias

         # reinit weights back to original after new masks are generated
...
@@ -13,14 +13,14 @@ logger = logging.getLogger(__name__)
 class NaiveQuantizer(Quantizer):
     """quantize weight to 8 bits
     """
-    def __init__(self, model, config_list):
-        super().__init__(model, config_list)
+    def __init__(self, model, config_list, optimizer=None):
+        super().__init__(model, config_list, optimizer)
         self.layer_scale = {}

-    def quantize_weight(self, weight, config, op_name, **kwargs):
+    def quantize_weight(self, weight, wrapper, **kwargs):
         new_scale = weight.abs().max() / 127
-        scale = max(self.layer_scale.get(op_name, 0), new_scale)
-        self.layer_scale[op_name] = scale
+        scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
+        self.layer_scale[wrapper.name] = scale
         orig_type = weight.type()  # TODO: user layer
         return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
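The quantize_weight body above is symmetric linear quantization to int8: scale by max|w|/127, truncate by casting to int8, then rescale. Spelled out on a toy tensor:

    import torch

    weight = torch.randn(4, 4)
    scale = weight.abs().max() / 127                # per-layer scale, kept monotone across calls via layer_scale
    quantized = weight.div(scale).type(torch.int8)  # the cast truncates; values land in [-127, 127]
    restored = quantized.type(weight.type()).mul(scale)
    assert (weight - restored).abs().max() < scale  # truncation error stays below one quantization step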
@@ -104,7 +104,7 @@ class QAT_Quantizer(Quantizer):
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
-    def __init__(self, model, config_list):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -124,9 +124,9 @@ class QAT_Quantizer(Quantizer):
                 - op_types : list of string
                     types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
-        super().__init__(model, config_list)
+        super().__init__(model, config_list, optimizer)
         self.steps = 1
-        modules_to_compress = self.detect_modules_to_compress()
+        modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             layer.module.register_buffer("zero_point", None)
             layer.module.register_buffer("scale", None)
@@ -181,7 +181,9 @@ class QAT_Quantizer(Quantizer):
         real_val = op.scale * (quantized_val - op.zero_point)
         return real_val

-    def quantize_weight(self, weight, config, op, **kwargs):
+    def quantize_weight(self, weight, wrapper, **kwargs):
+        config = wrapper.config
+        module = wrapper.module
         weight_bits = get_bits_length(config, 'weight')
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
@@ -189,12 +191,14 @@ class QAT_Quantizer(Quantizer):
         if quant_start_step > self.steps:
             return weight
         rmin, rmax = torch.min(weight), torch.max(weight)
-        op.scale, op.zero_point = update_quantization_param(weight_bits, rmin, rmax)
-        out = self._quantize(weight_bits, op, weight)
-        out = self._dequantize(op, out)
+        module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
+        out = self._quantize(weight_bits, module, weight)
+        out = self._dequantize(module, out)
         return out

-    def quantize_output(self, output, config, op, **kwargs):
+    def quantize_output(self, output, wrapper, **kwargs):
+        config = wrapper.config
+        module = wrapper.module
         output_bits = get_bits_length(config, 'output')
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"
@@ -203,18 +207,18 @@ class QAT_Quantizer(Quantizer):
             return output

         current_min, current_max = torch.min(output), torch.max(output)
-        op.tracked_min_biased, op.tracked_min = update_ema(op.tracked_min_biased, current_min, op.ema_decay, self.steps)
-        op.tracked_max_biased, op.tracked_max = update_ema(op.tracked_max_biased, current_max, op.ema_decay, self.steps)
-        op.scale, op.zero_point = update_quantization_param(output_bits, op.tracked_min, op.tracked_max)
-        out = self._quantize(output_bits, op, output)
-        out = self._dequantize(op, out)
+        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps)
+        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps)
+        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
+        out = self._quantize(output_bits, module, output)
+        out = self._dequantize(module, out)
         return out

     def fold_bn(self, config, **kwargs):
         # TODO simulate folded weight
         pass

-    def step(self):
+    def step_with_optimizer(self):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
@@ -226,11 +230,11 @@ class DoReFaQuantizer(Quantizer):
     Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
     (https://arxiv.org/abs/1606.06160)
     """
-    def __init__(self, model, config_list):
-        super().__init__(model, config_list)
+    def __init__(self, model, config_list, optimizer=None):
+        super().__init__(model, config_list, optimizer)

-    def quantize_weight(self, weight, config, **kwargs):
-        weight_bits = get_bits_length(config, 'weight')
+    def quantize_weight(self, weight, wrapper, **kwargs):
+        weight_bits = get_bits_length(wrapper.config, 'weight')
         out = weight.tanh()
         out = out / (2 * out.abs().max()) + 0.5
         out = self.quantize(out, weight_bits)
@@ -256,17 +260,17 @@ class BNNQuantizer(Quantizer):
     Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
     (https://arxiv.org/abs/1602.02830)
     """
-    def __init__(self, model, config_list):
-        super().__init__(model, config_list)
+    def __init__(self, model, config_list, optimizer=None):
+        super().__init__(model, config_list, optimizer)
         self.quant_grad = ClipGrad

-    def quantize_weight(self, weight, config, **kwargs):
+    def quantize_weight(self, weight, wrapper, **kwargs):
         out = torch.sign(weight)
         # remove zeros
         out[out == 0] = 1
         return out

-    def quantize_output(self, output, config, **kwargs):
+    def quantize_output(self, output, wrapper, **kwargs):
         out = torch.sign(output)
         # remove zeros
         out[out == 0] = 1
...
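quantize_output above tracks running min/max through update_ema, whose body this hunk does not show. A bias-corrected EMA consistent with the call sites here (step starts at 1, so the correction term never divides by zero) would look like the following; this is an assumption about the helper, not code from the diff:

    def update_ema(biased_ema, value, decay, step):
        # exponential moving average with Adam-style bias correction (assumed form)
        biased_ema = biased_ema * decay + (1 - decay) * value
        unbiased_ema = biased_ema / (1 - decay ** step)
        return biased_ema, unbiased_ema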
@@ -15,7 +15,7 @@ class WeightRankFilterPruner(Pruner):
     importance criterion in convolution layers to achieve a preset level of network sparsity.
     """

-    def __init__(self, model, config_list):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -24,15 +24,17 @@ class WeightRankFilterPruner(Pruner):
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
-        super().__init__(model, config_list)
-        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
+        super().__init__(model, config_list, optimizer)
+        self.set_wrappers_attribute("if_calculated", False)

     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config, **kwargs):
+    def calc_mask(self, wrapper, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion of the kernel weights are masked.
@@ -48,20 +50,21 @@ class WeightRankFilterPruner(Pruner):
             dictionary for storing masks
         """

-        weight = layer.module.weight.data
-        op_type = layer.type
+        weight = wrapper.module.weight.data
+        op_type = wrapper.type
+        config = wrapper.config
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
         assert op_type in config.get('op_types')
-        if_calculated = kwargs["if_calculated"]
-        if if_calculated:
+
+        if wrapper.if_calculated:
             return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
-        if hasattr(layer.module, 'bias') and layer.module.bias is not None:
-            mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
+        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
+            mask_bias = torch.ones(wrapper.module.bias.size()).type_as(wrapper.module.bias).detach()
         else:
             mask_bias = None
-        mask = {'weight': mask_weight, 'bias': mask_bias}
+        mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
         try:
             filters = weight.size(0)
             num_prune = int(filters * config.get('sparsity'))
@@ -69,7 +72,7 @@ class WeightRankFilterPruner(Pruner):
                 return mask
             mask = self.get_mask(mask, weight, num_prune)
         finally:
-            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+            wrapper.if_calculated = True
         return mask
@@ -82,7 +85,7 @@ class L1FilterPruner(WeightRankFilterPruner):
     https://arxiv.org/abs/1608.08710
     """

-    def __init__(self, model, config_list):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -91,9 +94,11 @@ class L1FilterPruner(WeightRankFilterPruner):
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer: torch.optim.Optimizer
+            Optimizer used to train model
         """
-        super().__init__(model, config_list)
+        super().__init__(model, config_list, optimizer)

     def get_mask(self, base_mask, weight, num_prune):
         """
@@ -121,7 +126,7 @@ class L1FilterPruner(WeightRankFilterPruner):
         mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
         mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight)

-        return {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
+        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
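For reference, the per-filter statistic behind w_abs_structured in L1FilterPruner is the L1 norm of each output filter's kernel; filters at or below the k-th smallest norm are masked. A toy computation:

    import torch

    weight = torch.randn(8, 3, 3, 3)                    # Conv2d weight: filters x in_channels x kH x kW
    w_abs_structured = weight.abs().sum(dim=(1, 2, 3))  # L1 norm of each of the 8 filters
    num_prune = 4
    threshold = torch.topk(w_abs_structured, num_prune, largest=False).values.max()
    mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)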

 class L2FilterPruner(WeightRankFilterPruner):
@@ -130,7 +135,7 @@ class L2FilterPruner(WeightRankFilterPruner):
     smallest L2 norm of the weights.
     """

-    def __init__(self, model, config_list):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -139,9 +144,11 @@ class L2FilterPruner(WeightRankFilterPruner):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train model
         """
-        super().__init__(model, config_list)
+        super().__init__(model, config_list, optimizer)

     def get_mask(self, base_mask, weight, num_prune):
         """
@@ -167,7 +174,7 @@ class L2FilterPruner(WeightRankFilterPruner):
         mask_weight = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
         mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight)
-        return {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
+        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
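To make the ranking statistics concrete, here is a standalone sketch of the per-filter scores the two pruners threshold on, mirroring the `w_abs_structured` and `w_l2_norm` lines above; the weight shape is assumed to be a Conv2d kernel `[out_channels, in_channels, k, k]`.

```python
import torch

weight = torch.randn(8, 3, 3, 3)                           # 8 filters, made-up weights
w_abs_structured = weight.abs().sum(dim=(1, 2, 3))         # L1FilterPruner score per filter
w_l2_norm = torch.sqrt((weight ** 2).sum(dim=(1, 2, 3)))   # L2FilterPruner score per filter
num_prune = 4
threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
mask_weight = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
assert mask_weight.sum().item() == (8 - num_prune) * 3 * 3 * 3  # 4 filters survive
```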

 class FPGMPruner(WeightRankFilterPruner):
@@ -177,7 +184,7 @@ class FPGMPruner(WeightRankFilterPruner):
     https://arxiv.org/pdf/1811.00250.pdf
     """

-    def __init__(self, model, config_list):
+    def __init__(self, model, config_list, optimizer):
         """
         Parameters
         ----------
@@ -186,8 +193,11 @@ class FPGMPruner(WeightRankFilterPruner):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train model
         """
-        super().__init__(model, config_list)
+        super().__init__(model, config_list, optimizer)
+        assert isinstance(optimizer, torch.optim.Optimizer), "FPGM pruner is an iterative pruner, please pass optimizer of the model to it"

     def get_mask(self, base_mask, weight, num_prune):
         """
@@ -208,9 +218,9 @@ class FPGMPruner(WeightRankFilterPruner):
         """
         min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
         for idx in min_gm_idx:
-            base_mask['weight'][idx] = 0.
-            if base_mask['bias'] is not None:
-                base_mask['bias'][idx] = 0.
+            base_mask['weight_mask'][idx] = 0.
+            if base_mask['bias_mask'] is not None:
+                base_mask['bias_mask'][idx] = 0.
         return base_mask

     def _get_min_gm_kernel_idx(self, weight, n):
@@ -258,4 +268,4 @@ class FPGMPruner(WeightRankFilterPruner):
     def update_epoch(self, epoch):
         for wrapper in self.get_modules_wrapper():
-            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0))  # pylint: disable=not-callable
+            wrapper.if_calculated = False
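A hedged end-to-end sketch for FPGMPruner, whose optimizer argument is now mandatory (enforced by the assert above). The import path and the training loop are placeholders.

```python
import torch
import torch.nn as nn
from nni.compression.torch import FPGMPruner  # import path assumed

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # required, not optional
pruner = FPGMPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}], optimizer)
model = pruner.compress()
for epoch in range(3):
    # ... train for one epoch ...
    pruner.update_epoch(epoch)  # resets if_calculated so masks are recomputed
```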
@@ -4,6 +4,7 @@
 import logging
 import datetime
 from nni.assessor import Assessor, AssessResult
+from nni.utils import extract_scalar_history
 from .model_factory import CurveModel

 logger = logging.getLogger('curvefitting_Assessor')
@@ -91,10 +92,11 @@ class CurvefittingAssessor(Assessor):
         Exception
             unrecognized exception in curvefitting_assessor
         """
-        self.trial_history = trial_history
+        scalar_trial_history = extract_scalar_history(trial_history)
+        self.trial_history = scalar_trial_history
         if not self.set_best_performance:
             return AssessResult.Good
-        curr_step = len(trial_history)
+        curr_step = len(scalar_trial_history)
         if curr_step < self.start_step:
             return AssessResult.Good
@@ -106,7 +108,7 @@ class CurvefittingAssessor(Assessor):
         start_time = datetime.datetime.now()
         # Predict the final result
         curvemodel = CurveModel(self.target_pos)
-        predict_y = curvemodel.predict(trial_history)
+        predict_y = curvemodel.predict(scalar_trial_history)
         logger.info('Prediction done. Trial job id = %s. Predict value = %s', trial_job_id, predict_y)
         if predict_y is None:
             logger.info('wait for more information to predict precisely')
......
@@ -3,6 +3,7 @@
 import logging

 from nni.assessor import Assessor, AssessResult
+from nni.utils import extract_scalar_history

 logger = logging.getLogger('medianstop_Assessor')
@@ -91,20 +92,12 @@ class MedianstopAssessor(Assessor):
         if curr_step < self._start_step:
             return AssessResult.Good

-        try:
-            num_trial_history = [float(ele) for ele in trial_history]
-        except (TypeError, ValueError) as error:
-            logger.warning('incorrect data type or value:')
-            logger.exception(error)
-        except Exception as error:
-            logger.warning('unrecognized exception in medianstop_assessor:')
-            logger.exception(error)
-
-        self._update_data(trial_job_id, num_trial_history)
+        scalar_trial_history = extract_scalar_history(trial_history)
+        self._update_data(trial_job_id, scalar_trial_history)
         if self._high_better:
-            best_history = max(trial_history)
+            best_history = max(scalar_trial_history)
         else:
-            best_history = min(trial_history)
+            best_history = min(scalar_trial_history)
         avg_array = []
         for id_ in self._completed_avg_history:
......
@@ -234,4 +234,5 @@ class MsgDispatcher(MsgDispatcherBase):
         if multi_thread_enabled():
             self._handle_final_metric_data(data)
         else:
+            data['value'] = to_json(data['value'])
             self.enqueue_command(CommandType.ReportMetricData, data)
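The intent of the added line is to serialize the metric value once before the command is queued. `to_json` is NNI's own helper (imported elsewhere in this file); `json.dumps` is shown below as a stand-in.

```python
import json

data = {'trial_job_id': 'abc', 'value': {'default': 0.93}}  # made-up payload
data['value'] = json.dumps(data['value'])  # stand-in for NNI's to_json
# the command queue now carries a JSON string instead of a nested dict
```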
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: skip-file
# This file is copied from PyTorch 1.4, with bug fixes.
# Likely to be removed in future.

import torch
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef
from torch.utils.tensorboard._pytorch_graph import GraphPy, CLASSTYPE_KIND, GETATTR_KIND, NodePyIO, NodePyOP


def parse(graph, trace, args=None, omit_useless_nodes=True):
    """This method parses an optimized PyTorch model graph and produces
    a list of nodes and node stats for eventual conversion to TensorBoard
    protobuf format.

    Args:
        graph (PyTorch module): The model graph to be parsed.
        trace (PyTorch JIT TracedModule): The model trace to be parsed.
        args (tuple): input tensor[s] for the model.
        omit_useless_nodes (boolean): Whether to remove nodes from the graph.
    """
    n_inputs = len(args)

    scope = {}
    nodes_py = GraphPy()
    for node in graph.inputs():
        if omit_useless_nodes:
            if len(node.uses()) == 0:  # number of user of the node (= number of outputs/ fanout)
                continue

        if node.type().kind() != CLASSTYPE_KIND:
            nodes_py.append(NodePyIO(node, 'input'))

    attr_to_scope = dict()
    node_to_name = lambda d: str(d).split(":")[0].strip()
    for node in graph.nodes():
        if node.kind() == GETATTR_KIND:
            attr_name = node.s('name')
            node_name = node_to_name(node)
            parent = node.input().node()
            if parent.kind() == GETATTR_KIND:  # If the parent node is not the top-level "self" node
                parent_attr_name = parent.s('name')
                parent_scope = attr_to_scope[node_to_name(parent)]
                attr_scope = parent_scope.split('/')[-1]
                attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
            else:
                attr_to_scope[node_name] = '__module.{}'.format(attr_name)
            # We don't need classtype nodes; scope will provide this information
            if node.output().type().kind() != CLASSTYPE_KIND:
                node_py = NodePyOP(node)
                node_py.scopeName = attr_to_scope[node_name]
                nodes_py.append(node_py)
        else:
            nodes_py.append(NodePyOP(node))

    for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
        node_py = NodePyIO(node, 'output')
        node_py.debugName = "output.{}".format(i + 1)
        node_py.inputs = [node.debugName()]
        nodes_py.append(node_py)

    def parse_traced_name(module_name):
        prefix = 'TracedModule['
        suffix = ']'
        if module_name.startswith(prefix) and module_name.endswith(suffix):
            module_name = module_name[len(prefix):-len(suffix)]
        return module_name

    alias_to_name = dict()
    base_name = parse_traced_name(trace._name)
    for name, module in trace.named_modules(prefix='__module'):
        mod_name = parse_traced_name(module._name)
        attr_name = name.split('.')[-1]
        alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)

    for node in nodes_py.nodes_op:
        module_aliases = node.scopeName.split('/')[-1].split('.')
        module_name = ''
        for i, alias in enumerate(module_aliases):
            if i == 0:
                module_name = alias
                node.scopeName = base_name
            else:
                module_name += '.' + alias
                node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)

    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py.to_proto()


def graph(model, args, verbose=False):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
        model (PyTorch module): The model to be parsed.
        args (tuple): input tensor[s] for the model.
        verbose (bool): Whether to print out verbose information while
            processing.
    """
    with torch.onnx.set_training(model, False):  # TODO: move outside of torch.onnx?
        try:
            trace = torch.jit.trace(model, args)
            graph = trace.graph
            torch._C._jit_pass_inline(graph)
        except RuntimeError as e:
            print(e)
            print('Error occurs, No graph saved')
            raise e

    if verbose:
        print(graph)
    list_of_nodes = parse(graph, trace, args)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
    # The producer version has been reverse engineered from standard
    # TensorBoard logged data.
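A hedged sketch of driving the copied `graph()` helper defined above, end to end; this targets the PyTorch 1.4-era API, and the model and input shapes are placeholders.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
dummy_input = (torch.randn(1, 3, 32, 32),)
graph_def, step_stats = graph(model, dummy_input)
print(len(graph_def.node))                        # number of parsed nodes
print(step_stats.step_stats.dev_stats[0].device)  # "/device:CPU:0" (hardcoded above)
```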
@@ -23,7 +23,9 @@ class StackedLSTMCell(nn.Module):
             curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
             next_c.append(curr_c)
             next_h.append(curr_h)
-            inputs = curr_h[-1]
+            # the current implementation only supports a batch size of 1,
+            # but the algorithm does not necessarily have this limitation
+            inputs = curr_h[-1].view(1, -1)
         return next_c, next_h
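Why the `view(1, -1)`: with a batch size of 1, indexing `curr_h[-1]` drops the batch dimension, and the next cell expects a 2-D `[batch, hidden]` input. A tiny shape check under that assumption:

```python
import torch

hidden = torch.randn(1, 64)   # LSTMCell hidden state, batch size 1
flat = hidden[-1]             # shape [64] -- batch dimension lost
restored = flat.view(1, -1)   # shape [1, 64] again
assert restored.shape == (1, 64)
```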
......
@@ -2,7 +2,9 @@
 # Licensed under the MIT license.

 import logging
+from collections import defaultdict

+import numpy as np
 import torch

 from nni.nas.pytorch.base_mutator import BaseMutator
@@ -15,6 +17,7 @@ class Mutator(BaseMutator):
     def __init__(self, model):
         super().__init__(model)
         self._cache = dict()
+        self._connect_all = False

     def sample_search(self):
         """
@@ -57,6 +60,74 @@ class Mutator(BaseMutator):
         """
         return self.sample_final()
+
+    def status(self):
+        """
+        Return the current selection status of the mutator.
+
+        Returns
+        -------
+        dict
+            A mapping from keys of mutables to decisions. All weights (boolean type and float type)
+            are converted into real number values. Numpy arrays and tensors are converted into lists.
+        """
+        data = dict()
+        for k, v in self._cache.items():
+            if torch.is_tensor(v):
+                v = v.detach().cpu().numpy()
+            if isinstance(v, np.ndarray):
+                v = v.astype(np.float32).tolist()
+            data[k] = v
+        return data
+
+    def graph(self, inputs):
+        """
+        Return the model supernet graph.
+
+        Parameters
+        ----------
+        inputs : tuple of tensor
+            Inputs that will be fed into the network.
+
+        Returns
+        -------
+        dict
+            Containing ``node``, in TensorBoard GraphDef format.
+            Additional key ``mutable`` is a map from key to list of modules.
+        """
+        if not torch.__version__.startswith("1.4"):
+            logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
+        from ._graph_utils import graph
+        from google.protobuf import json_format
+        # protobuf should be installed as long as tensorboard is installed
+        try:
+            self._connect_all = True
+            graph_def, _ = graph(self.model, inputs, verbose=False)
+            result = json_format.MessageToDict(graph_def)
+        finally:
+            self._connect_all = False
+
+        # `mutable` maps each key to a list of corresponding modules.
+        # A key can be linked to multiple modules; use `deduplicate=False` to find them all.
+        result["mutable"] = defaultdict(list)
+        for mutable in self.mutables.traverse(deduplicate=False):
+            # A module will be represented in the format of
+            # [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d", "name": "conv"}]
+            # which will be concatenated into Net/Cell[cell1]/Conv2d[conv] in the frontend.
+            # This format is aligned with the scope names that jit gives.
+            modules = mutable.name.split(".")
+            path = [
+                {"type": self.model.__class__.__name__, "name": ""}
+            ]
+            m = self.model
+            for module in modules:
+                m = getattr(m, module)
+                path.append({
+                    "type": m.__class__.__name__,
+                    "name": module
+                })
+            result["mutable"][mutable.key].append(path)
+        return result
+
     def on_forward_layer_choice(self, mutable, *inputs):
         """
         By default, this method retrieves the decision obtained previously, and selects certain operations.
@@ -75,6 +146,11 @@ class Mutator(BaseMutator):
         tuple of torch.Tensor and torch.Tensor
             Output and mask.
         """
+        if self._connect_all:
+            return self._all_connect_tensor_reduction(mutable.reduction,
+                                                      [op(*inputs) for op in mutable.choices]), \
+                torch.ones(mutable.length)
+
         def _map_fn(op, *inputs):
             return op(*inputs)
@@ -101,6 +177,9 @@ class Mutator(BaseMutator):
         tuple of torch.Tensor and torch.Tensor
             Output and mask.
         """
+        if self._connect_all:
+            return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \
+                torch.ones(mutable.n_candidates)
         mask = self._get_decision(mutable)
         assert len(mask) == mutable.n_candidates, \
             "Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
@@ -131,6 +210,13 @@ class Mutator(BaseMutator):
             return torch.cat(tensor_list, dim=1)
         raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
+    def _all_connect_tensor_reduction(self, reduction_type, tensor_list):
+        if reduction_type == "none":
+            return tensor_list
+        if reduction_type == "concat":
+            return torch.cat(tensor_list, dim=1)
+        return torch.stack(tensor_list).sum(0)
+
     def _get_decision(self, mutable):
         """
         By default, this method checks whether `mutable.key` is already in the decision cache,
......
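A hedged sketch of the two new introspection APIs. The `LayerChoice` and `RandomMutator` import paths are assumed from NNI's NAS package layout of the time and are not part of this diff; `graph()` additionally requires PyTorch 1.4.

```python
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice  # import path assumed
from nni.nas.pytorch.random import RandomMutator  # import path assumed

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = LayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                                 nn.Conv2d(3, 8, 5, padding=2)])

    def forward(self, x):
        return self.conv(x)

model = Net()
mutator = RandomMutator(model)
mutator.reset()                                       # sample decisions into the cache
print(mutator.status())                               # decisions as plain lists/numbers
result = mutator.graph((torch.randn(1, 3, 32, 32),))  # supernet graph with all paths connected
print(list(result["mutable"].keys()))                 # mutable key -> module paths
```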
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
     from .standalone import *
 elif trial_env_vars.NNI_PLATFORM == 'unittest':
     from .test import *
-elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn'):
+elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'):
     from .local import *
 else:
     raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
@@ -62,6 +62,13 @@ def extract_scalar_reward(value, scalar_key='default'):
     """
     Extract scalar reward from trial result.

+    Parameters
+    ----------
+    value : int, float, dict
+        the reported final metric data
+    scalar_key : str
+        the key name that indicates the scalar value
+
     Raises
     ------
     RuntimeError
@@ -78,6 +85,26 @@ def extract_scalar_reward(value, scalar_key='default'):
     return reward

+
+def extract_scalar_history(trial_history, scalar_key='default'):
+    """
+    Extract scalar values from a list of intermediate results.
+
+    Parameters
+    ----------
+    trial_history : list
+        accumulated intermediate results of a trial
+    scalar_key : str
+        the key name that indicates the scalar value
+
+    Raises
+    ------
+    RuntimeError
+        Incorrect result: each result should be float/int,
+        or a dict which has a key named "default" whose value is float/int.
+    """
+    return [extract_scalar_reward(ele, scalar_key) for ele in trial_history]
+
+
 def convert_dict2tuple(value):
     """
     convert dict type to tuple to solve unhashable problem.
@@ -90,7 +117,9 @@ def convert_dict2tuple(value):
 def init_dispatcher_logger():
-    """ Initialize dispatcher logging configuration"""
+    """
+    Initialize dispatcher logging configuration
+    """
     logger_file_path = 'dispatcher.log'
     if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
         logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
......
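Both assessors above now delegate per-element conversion to this helper instead of hand-rolled float casts. An illustration with made-up intermediate results:

```python
from nni.utils import extract_scalar_history

history = [0.61, {'default': 0.68, 'accuracy': 0.70}, 0.72]
print(extract_scalar_history(history))  # -> [0.61, 0.68, 0.72]
```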
@@ -59,29 +59,34 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
         });
         // true: parameters are wrong
-        let flag = false;
+        let parametersIllegal = false;
         Object.keys(customized).map(item => {
             if (item !== 'tag') {
                 // unified data type
                 if (typeof copyTrialParameter[item] === 'number' && typeof customized[item] === 'string') {
                     customized[item] = JSON.parse(customized[item]);
                 }
+                if (searchSpace[item] === undefined) {
+                    // sometimes the schema of trial parameters is different from search space
+                    // e.g. Batch Tuner
+                    return;
+                }
                 if (searchSpace[item]._type === 'choice') {
                     if (searchSpace[item]._value.find((val: string | number) =>
                         val === customized[item]) === undefined) {
-                        flag = true;
+                        parametersIllegal = true;
                         return;
                     }
                 } else {
                     if (customized[item] < searchSpace[item]._value[0]
                         || customized[item] > searchSpace[item]._value[1]) {
-                        flag = true;
+                        parametersIllegal = true;
                         return;
                     }
                 }
             }
         });
-        if (flag !== false) {
+        if (parametersIllegal !== false) {
             // open the warning modal
             this.setState(() => ({ isShowWarning: true, customParameters: customized }));
         } else {
@@ -269,4 +274,4 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
     }
 }

-export default Customize;
\ No newline at end of file
+export default Customize;