Unverified commit 75028bd7 authored by SparkSnail, committed by GitHub

Merge pull request #235 from microsoft/master

merge master
parents 1d74ae5e 2e42d1d8
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { TrialConfig } from "training_service/common/trialConfig";
export class DLTSTrialConfig extends TrialConfig {
public constructor(
command: string,
codeDir: string,
gpuNum: number,
public readonly image: string
) {
super(command, codeDir, gpuNum);
}
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import {
TrialJobDetail,
TrialJobStatus,
TrialJobApplicationForm
} from "../../common/trainingService";
export class DLTSTrialJobDetail implements TrialJobDetail {
public startTime?: number;
public endTime?: number;
public tags?: string[];
public url?: string;
public isEarlyStopped?: boolean;
// DLTS stuff
public dltsJobId?: string;
public dltsPaused: boolean = false;
public constructor (
public id: string,
public status: TrialJobStatus,
public submitTime: number,
public workingDirectory: string,
public form: TrialJobApplicationForm,
// DLTS stuff
public dltsJobName: string,
) {}
}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .compressor import Compressor, Pruner, Quantizer
from .pruners import *
from .weight_rank_filter_pruners import *
from .activation_rank_filter_pruners import *
......
......@@ -16,7 +16,7 @@ class ActivationRankFilterPruner(Pruner):
to achieve a preset level of network sparsity.
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
......@@ -25,17 +25,23 @@ class ActivationRankFilterPruner(Pruner):
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list)
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
self.set_wrappers_attribute("collected_activation", [])
self.statistics_batch_num = statistics_batch_num
self.collected_activation = {}
self.hooks = {}
def collector(module_, input_, output):
if len(module_.collected_activation) < self.statistics_batch_num:
module_.collected_activation.append(self.activation(output.detach().cpu()))
self.add_activation_collector(collector)
assert activation in ['relu', 'relu6']
if activation == 'relu':
self.activation = torch.nn.functional.relu
......@@ -44,33 +50,10 @@ class ActivationRankFilterPruner(Pruner):
else:
self.activation = None
def compress(self):
"""
Compress the model, register a hook for collecting activations.
"""
if self.modules_wrapper is not None:
# already compressed
return self.bound_model
else:
self.modules_wrapper = []
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self.collected_activation[layer.name] = []
def _hook(module_, input_, output, name=layer.name):
if len(self.collected_activation[name]) < self.statistics_batch_num:
self.collected_activation[name].append(self.activation(output.detach().cpu()))
wrapper.module.register_forward_hook(_hook)
self._wrap_model()
return self.bound_model
def get_mask(self, base_mask, activations, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
......@@ -88,29 +71,30 @@ class ActivationRankFilterPruner(Pruner):
dictionary for storing masks
"""
weight = layer.module.weight.data
op_type = layer.type
weight = wrapper.module.weight.data
op_type = wrapper.type
config = wrapper.config
assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
assert op_type in ['Conv2d'], "only support Conv2d"
assert op_type in config.get('op_types')
if_calculated = kwargs["if_calculated"]
if if_calculated:
if wrapper.if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
mask_bias = torch.ones(wrapper.module.bias.size()).type_as(wrapper.module.bias).detach()
else:
mask_bias = None
mask = {'weight': mask_weight, 'bias': mask_bias}
mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1 or len(self.collected_activation[layer.name]) < self.statistics_batch_num:
if filters < 2 or num_prune < 1 or len(wrapper.collected_activation) < self.statistics_batch_num:
return mask
mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
mask = self.get_mask(mask, wrapper.collected_activation, num_prune)
finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
if len(wrapper.collected_activation) == self.statistics_batch_num:
wrapper.if_calculated = True
return mask
......@@ -123,7 +107,7 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
https://arxiv.org/abs/1607.03250
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
......@@ -132,12 +116,14 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
super().__init__(model, config_list, optimizer, activation, statistics_batch_num)
def get_mask(self, base_mask, activations, num_prune):
"""
......@@ -161,9 +147,9 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
apoz = self._calc_apoz(activations)
prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def _calc_apoz(self, activations):
......@@ -195,7 +181,7 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
https://arxiv.org/abs/1611.06440
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
......@@ -204,12 +190,14 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
super().__init__(model, config_list, optimizer, activation, statistics_batch_num)
def get_mask(self, base_mask, activations, num_prune):
"""
......@@ -233,9 +221,9 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
mean_activation = self._cal_mean_activation(activations)
prune_indices = torch.argsort(mean_activation)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def _cal_mean_activation(self, activations):
......
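Below is a minimal usage sketch of the updated activation-based filter pruners (not part of this diff): the pruner now takes the training optimizer, registers activation collectors as forward hooks on the module wrappers, and recomputes masks whenever the patched optimizer.step() runs. The import path, toy model, and config values are assumptions for illustration.

import torch
from nni.compression.torch import ActivationAPoZRankFilterPruner

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3), torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3), torch.nn.ReLU(),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

pruner = ActivationAPoZRankFilterPruner(model, config_list, optimizer=optimizer,
                                        activation='relu', statistics_batch_num=1)
model = pruner.compress()

# forward passes feed the collectors registered on the wrappers; once
# statistics_batch_num batches have been seen, the next (patched)
# optimizer.step() recomputes the APoZ-based filter masks
model(torch.randn(4, 3, 32, 32))
optimizer.step()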
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import types
import logging
import torch
from . import default_layers
......@@ -20,12 +21,13 @@ def _setattr(model, name, module):
model = getattr(model, name)
setattr(model, name_list[-1], module)
class Compressor:
"""
Abstract base PyTorch compressor
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer=None):
"""
Record necessary info in class members
......@@ -35,15 +37,27 @@ class Compressor:
the model user wants to compress
config_list : list
the configurations that users specify for compression
optimizer: pytorch optimizer
optimizer used to train the model
"""
self.bound_model = model
self.config_list = config_list
self.optimizer = optimizer
self.modules_to_compress = None
self.modules_wrapper = None
self.buffers = {}
self.modules_wrapper = []
self.is_wrapped = False
def detect_modules_to_compress(self):
self._fwd_hook_handles = {}
self._fwd_hook_id = 0
for layer, config in self._detect_modules_to_compress():
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self._wrap_model()
def _detect_modules_to_compress(self):
"""
Detect all modules that should be compressed, and save the result in `self.modules_to_compress`.
The model will be instrumented and the user should never edit it after calling this method.
......@@ -87,26 +101,26 @@ class Compressor:
torch.nn.Module
model with specified modules compressed.
"""
if self.modules_wrapper is not None:
# already compressed
return self.bound_model
else:
self.modules_wrapper = []
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self._wrap_model()
return self.bound_model
def register_buffer(self, name, value):
def set_wrappers_attribute(self, name, value):
"""
To register buffers used in wrapped module's forward method.
To register attributes used in wrapped module's forward method.
If the type of the value is torch.Tensor, then this value is registered as a buffer in the wrapper,
which will be saved by model.state_dict. Otherwise, this value is just a regular variable in the wrapper.
Parameters
----------
name : str
name of the variable
value: any
value of the variable
"""
self.buffers[name] = value
for wrapper in self.get_modules_wrapper():
if isinstance(value, torch.Tensor):
wrapper.register_buffer(name, value.clone())
else:
setattr(wrapper, name, value)
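The two branches of set_wrappers_attribute behave differently: tensor values become registered buffers (and therefore appear in the model's state_dict), while everything else becomes a plain attribute on the wrapper. A small self-contained sketch of the distinction, using a dummy wrapper module rather than the real PrunerModuleWrapper:

import torch

class DummyWrapper(torch.nn.Module):
    pass

w = DummyWrapper()
# a tensor value is registered as a buffer, so it is saved by state_dict()
w.register_buffer("if_calculated", torch.tensor(0.))
# any other value is just a regular Python attribute on the wrapper
setattr(w, "collected_activation", [])

assert "if_calculated" in w.state_dict()
assert "collected_activation" not in w.state_dict()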
def get_modules_to_compress(self):
"""
......@@ -180,11 +194,7 @@ class Compressor:
epoch : num
the current epoch number
"""
def step(self):
"""
If user want to update model every step, user can override this method
"""
pass
def _wrap_modules(self, layer, config):
"""
......@@ -200,6 +210,34 @@ class Compressor:
raise NotImplementedError()
def add_activation_collector(self, collector):
self._fwd_hook_id += 1
self._fwd_hook_handles[self._fwd_hook_id] = []
for wrapper in self.get_modules_wrapper():
handle = wrapper.register_forward_hook(collector)
self._fwd_hook_handles[self._fwd_hook_id].append(handle)
return self._fwd_hook_id
def remove_activation_collector(self, fwd_hook_id):
if fwd_hook_id not in self._fwd_hook_handles:
raise ValueError("%s is not a valid collector id" % str(fwd_hook_id))
for handle in self._fwd_hook_handles[fwd_hook_id]:
handle.remove()
del self._fwd_hook_handles[fwd_hook_id]
def patch_optimizer(self, *tasks):
def patch_step(old_step):
def new_step(_, *args, **kwargs):
# call origin optimizer step method
output = old_step(*args, **kwargs)
# calculate mask
for task in tasks:
task()
return output
return new_step
if self.optimizer is not None:
self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
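patch_optimizer replaces the optimizer's bound step method with a closure that first calls the original step and then runs the extra tasks (for pruners, update_mask). A self-contained sketch of the same pattern outside NNI; the toy parameter and the update_mask stub are illustrative assumptions:

import types
import torch

calls = []

def update_mask():
    calls.append('mask updated')

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

def patch_step(old_step):
    def new_step(_, *args, **kwargs):
        output = old_step(*args, **kwargs)   # run the original step first
        update_mask()                        # then run the extra task(s)
        return output
    return new_step

opt.step = types.MethodType(patch_step(opt.step), opt)
opt.step()                                   # SGD update, then the extra task
assert calls == ['mask updated']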
class PrunerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, pruner):
"""
......@@ -226,7 +264,6 @@ class PrunerModuleWrapper(torch.nn.Module):
# config and pruner
self.config = config
self.pruner = pruner
self.registered_buffers = []
# register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
......@@ -234,29 +271,11 @@ class PrunerModuleWrapper(torch.nn.Module):
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
else:
self.register_buffer("bias_mask", None)
self.registered_buffers.append('weight_mask')
self.registered_buffers.append('bias_mask')
# register user specified buffer
for name in self.pruner.buffers:
self.register_buffer(name, self.pruner.buffers[name].clone())
self.registered_buffers.append(name)
def get_registered_buffers(self):
buffers = {}
for name in self.registered_buffers:
buffers[name] = getattr(self, name)
return buffers
def forward(self, *inputs):
mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.get_registered_buffers())
if mask is not None:
self.weight_mask.copy_(mask['weight'])
# apply mask to weight
# apply mask to weight, bias
self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
# apply mask to bias
if hasattr(self.module, 'bias') and self.module.bias is not None:
if mask is not None and 'bias' in mask:
self.bias_mask.copy_(mask['bias'])
self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
return self.module(*inputs)
......@@ -272,10 +291,24 @@ class Pruner(Compressor):
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
if optimizer is not None:
self.patch_optimizer(self.update_mask)
def compress(self):
self.update_mask()
return self.bound_model
def update_mask(self):
for wrapper in self.get_modules_wrapper():
masks = self.calc_mask(wrapper)
if masks is not None:
for k in masks:
assert hasattr(wrapper, k), "there is no attribute '%s' in wrapper" % k
setattr(wrapper, k, masks[k])
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Pruners should overload this method to provide a mask for weight tensors.
The mask must have the same shape and type as the weight.
......@@ -284,10 +317,8 @@ class Pruner(Compressor):
Parameters
----------
layer : LayerInfo
calculate mask for `layer`'s weight
config : dict
the configuration for generating the mask
wrapper : Module
calculate mask for `wrapper.module`'s weight
"""
raise NotImplementedError("Pruners must overload calc_mask()")
......@@ -327,8 +358,6 @@ class Pruner(Compressor):
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
"""
# if self.detect_modules_to_compress() and not self.mask_dict:
# _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
assert model_path is not None, 'model_path must be specified'
mask_dict = {}
self._unwrap_model() # used for generating correct state_dict name without wrapper state
......@@ -404,7 +433,6 @@ class QuantizerModuleWrapper(torch.nn.Module):
# config and pruner
self.config = config
self.quantizer = quantizer
self.registered_buffers = []
# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
......@@ -418,35 +446,18 @@ class QuantizerModuleWrapper(torch.nn.Module):
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.module.old_weight)
# register user specified buffer
for name in self.quantizer.buffers:
self.register_buffer(name, self.quantizer.buffers[name].clone())
self.registered_buffers.append(name)
def get_registered_buffers(self):
buffers = {}
for name in self.registered_buffers:
buffers[name] = getattr(self, name)
return buffers
def forward(self, *inputs):
if 'input' in self.config['quant_types']:
inputs = self.quantizer.quant_grad.apply(
inputs,
QuantType.QUANT_INPUT,
self.quantizer.quantize_input,
self.config,
LayerInfo(self.name, self.module),
**self.get_registered_buffers())
self)
if 'weight' in self.config['quant_types'] and _check_weight(self.module):
new_weight = self.quantizer.quant_grad.apply(
self.module.old_weight,
QuantType.QUANT_WEIGHT,
self.quantizer.quantize_weight,
self.config,
LayerInfo(self.name, self.module),
**self.get_registered_buffers())
self)
self.module.weight = new_weight
result = self.module(*inputs)
else:
......@@ -456,10 +467,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
result = self.quantizer.quant_grad.apply(
result,
QuantType.QUANT_OUTPUT,
self.quantizer.quantize_output,
self.config,
LayerInfo(self.name, self.module),
**self.get_registered_buffers())
self)
return result
class Quantizer(Compressor):
......@@ -467,11 +475,18 @@ class Quantizer(Compressor):
Base quantizer for pytorch quantizer
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = QuantGrad
if self.optimizer is not None:
self.patch_optimizer(self.step_with_optimizer)
for wrapper in self.get_modules_wrapper():
if 'weight' in wrapper.config['quant_types']:
# old_weight is registered to keep track of weight before quantization
# and it is trainable, therefore, it should be added to optimizer.
self.optimizer.add_param_group({"params": wrapper.module.old_weight})
def quantize_weight(self, weight, config, op, op_type, op_name):
def quantize_weight(self, weight, wrapper, **kwargs):
"""
Quantizers should overload this method to quantize weight.
This method is effectively hooked to :meth:`forward` of the model.
......@@ -479,12 +494,12 @@ class Quantizer(Compressor):
----------
weight : Tensor
weight that needs to be quantized
config : dict
the configuration for weight quantization
wrapper : QuantizerModuleWrapper
the wrapper for origin module
"""
raise NotImplementedError('Quantizer must overload quantize_weight()')
def quantize_output(self, output, config, op, op_type, op_name):
def quantize_output(self, output, wrapper, **kwargs):
"""
Quantizers should overload this method to quantize output.
This method is effectively hooked to :meth:`forward` of the model.
......@@ -492,12 +507,12 @@ class Quantizer(Compressor):
----------
output : Tensor
output that needs to be quantized
config : dict
the configuration for output quantization
wrapper : QuantizerModuleWrapper
the wrapper for origin module
"""
raise NotImplementedError('Quantizer must overload quantize_output()')
def quantize_input(self, *inputs, config, op, op_type, op_name):
def quantize_input(self, *inputs, wrapper, **kwargs):
"""
Quantizers should overload this method to quantize input.
This method is effectively hooked to :meth:`forward` of the model.
......@@ -505,8 +520,8 @@ class Quantizer(Compressor):
----------
inputs : Tensor
inputs that needs to be quantized
config : dict
the configuration for inputs quantization
wrapper : QuantizerModuleWrapper
the wrapper for origin module
"""
raise NotImplementedError('Quantizer must overload quantize_input()')
......@@ -532,6 +547,9 @@ class Quantizer(Compressor):
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
def step_with_optimizer(self):
pass
class QuantType:
"""
Enum class for quantization type.
......@@ -540,6 +558,7 @@ class QuantType:
QUANT_WEIGHT = 1
QUANT_OUTPUT = 2
class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
......@@ -566,15 +585,22 @@ class QuantGrad(torch.autograd.Function):
return grad_output
@staticmethod
def forward(ctx, tensor, quant_type, quant_func, config, layer, **kwargs):
def forward(ctx, tensor, quant_type, wrapper, **kwargs):
ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name, **kwargs)
if quant_type == QuantType.QUANT_INPUT:
return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
return wrapper.quantizer.quantize_weight(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
raise ValueError("unrecognized QuantType.")
@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, quant_type)
return output, None, None, None, None, None
return output, None, None, None
def _check_weight(module):
try:
......
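Under the new contract, calc_mask receives the module wrapper and returns a dict keyed by the wrapper's mask buffers ('weight_mask', and optionally 'bias_mask'); update_mask then assigns each entry onto the wrapper. A minimal custom pruner written against this contract; the import path and the fixed 0.5 threshold are assumptions, not part of this diff:

import torch
from nni.compression.torch import Pruner

class ThresholdPruner(Pruner):
    def calc_mask(self, wrapper, **kwargs):
        weight = wrapper.module.weight.data
        # keys must match the mask buffers registered on PrunerModuleWrapper
        return {'weight_mask': (weight.abs() > 0.5).type_as(weight)}

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
pruner = ThresholdPruner(model, [{'op_types': ['Linear']}])
model = pruner.compress()   # update_mask() sets wrapper.weight_mask on every wrapper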
......@@ -16,7 +16,7 @@ class LevelPruner(Pruner):
Prune to an exact pruning level specification
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
......@@ -24,38 +24,39 @@ class LevelPruner(Pruner):
Model to be pruned
config_list : list
List on pruning configs
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list)
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
wrapper : Module
the module to instrument the compression operation
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
if_calculated = kwargs["if_calculated"]
config = wrapper.config
weight = wrapper.module.weight.data
if not if_calculated:
if not wrapper.if_calculated:
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight': mask_weight}
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
mask = {'weight_mask': mask_weight}
wrapper.if_calculated = True
return mask
else:
return None
......@@ -71,7 +72,7 @@ class AGP_Pruner(Pruner):
https://arxiv.org/pdf/1710.01878.pdf
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
......@@ -79,50 +80,51 @@ class AGP_Pruner(Pruner):
Model to be pruned
config_list : list
List on pruning configs
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
self.now_epoch = 0
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
self.set_wrappers_attribute("if_calculated", False)
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of the given layer.
Weights with the smallest absolute values are masked until the scheduled target sparsity for the current epoch is reached.
Parameters
----------
layer : LayerInfo
wrapper : Module
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs: dict
buffers registered in __init__ function
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
config = wrapper.config
weight = wrapper.module.weight.data
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if_calculated = kwargs["if_calculated"]
if if_calculated:
if wrapper.if_calculated:
return None
if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
return None
mask = {'weight': kwargs['weight_mask'] if 'weight_mask' in kwargs else torch.ones(weight.shape).type_as(weight)}
mask = {'weight_mask': wrapper.weight_mask}
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate a new mask, we should update the weight first
w_abs = weight.abs() * mask['weight']
w_abs = weight.abs() * mask['weight_mask']
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
new_mask = {'weight_mask': torch.gt(w_abs, threshold).type_as(weight)}
wrapper.if_calculated = True
return new_mask
......@@ -180,62 +182,64 @@ class SlimPruner(Pruner):
https://arxiv.org/pdf/1708.06519.pdf
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
for (layer, config) in self.get_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
self.set_wrappers_attribute("if_calculated", False)
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
layer : LayerInfo
wrapper : Module
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs: dict
buffers registered in __init__ function
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_type = layer.type
if_calculated = kwargs["if_calculated"]
config = wrapper.config
weight = wrapper.module.weight.data
op_type = wrapper.type
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if if_calculated:
if wrapper.if_calculated:
return None
base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
mask = {'weight_mask': base_mask.detach(), 'bias_mask': base_mask.clone().detach()}
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters >= 2 and num_prune >= 1:
w_abs = weight.abs()
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
mask = {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
wrapper.if_calculated = True
return mask
class LotteryTicketPruner(Pruner):
......@@ -250,7 +254,7 @@ class LotteryTicketPruner(Pruner):
5. Repeat step 2, 3, and 4.
"""
def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
def __init__(self, model, config_list, optimizer=None, lr_scheduler=None, reset_weights=True):
"""
Parameters
----------
......@@ -267,7 +271,7 @@ class LotteryTicketPruner(Pruner):
reset_weights : bool
Whether reset weights and optimizer at the beginning of each round.
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
self.curr_prune_iteration = None
self.prune_iterations = self._validate_config(config_list)
......@@ -307,20 +311,16 @@ class LotteryTicketPruner(Pruner):
k = int(w_abs.numel() * curr_sparsity)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs, threshold).type_as(weight)
return {'weight': mask}
return {'weight_mask': mask}
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Generate mask for the given ``weight``.
Parameters
----------
layer : LayerInfo
wrapper : Module
The layer to be pruned
config : dict
Pruning configurations for this weight
kwargs : dict
Auxiliary information
Returns
-------
......@@ -355,7 +355,7 @@ class LotteryTicketPruner(Pruner):
assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'
modules_wrapper = self.get_modules_wrapper()
modules_to_compress = self.detect_modules_to_compress()
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
module_wrapper = None
for wrapper in modules_wrapper:
......@@ -367,7 +367,7 @@ class LotteryTicketPruner(Pruner):
sparsity = config.get('sparsity')
mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
# TODO: directly use weight_mask is not good
module_wrapper.weight_mask.copy_(mask['weight'])
module_wrapper.weight_mask = mask['weight_mask']
# there is no mask for bias
# reinit weights back to original after new masks are generated
......
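A hedged usage sketch of the AGP pruner after this change: the optimizer is now required so that masks are recomputed on every optimizer.step(), while update_epoch advances the sparsity schedule. The import path, model, data loader, and schedule values are assumptions for illustration.

import torch
from nni.compression.torch import AGP_Pruner

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
config_list = [{
    'initial_sparsity': 0.0,
    'final_sparsity': 0.8,
    'start_epoch': 0,
    'end_epoch': 10,
    'frequency': 1,
    'op_types': ['default'],
}]
pruner = AGP_Pruner(model, config_list, optimizer)
model = pruner.compress()

for epoch in range(10):
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()        # patched: also calls pruner.update_mask()
    pruner.update_epoch(epoch)  # advances the sparsity schedule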
......@@ -13,14 +13,14 @@ logger = logging.getLogger(__name__)
class NaiveQuantizer(Quantizer):
"""quantize weight to 8 bits
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.layer_scale = {}
def quantize_weight(self, weight, config, op_name, **kwargs):
def quantize_weight(self, weight, wrapper, **kwargs):
new_scale = weight.abs().max() / 127
scale = max(self.layer_scale.get(op_name, 0), new_scale)
self.layer_scale[op_name] = scale
scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
self.layer_scale[wrapper.name] = scale
orig_type = weight.type() # TODO: user layer
return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
......@@ -104,7 +104,7 @@ class QAT_Quantizer(Quantizer):
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
......@@ -124,9 +124,9 @@ class QAT_Quantizer(Quantizer):
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
self.steps = 1
modules_to_compress = self.detect_modules_to_compress()
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", None)
layer.module.register_buffer("scale", None)
......@@ -181,7 +181,9 @@ class QAT_Quantizer(Quantizer):
real_val = op.scale * (quantized_val - op.zero_point)
return real_val
def quantize_weight(self, weight, config, op, **kwargs):
def quantize_weight(self, weight, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"
......@@ -189,12 +191,14 @@ class QAT_Quantizer(Quantizer):
if quant_start_step > self.steps:
return weight
rmin, rmax = torch.min(weight), torch.max(weight)
op.scale, op.zero_point = update_quantization_param(weight_bits, rmin, rmax)
out = self._quantize(weight_bits, op, weight)
out = self._dequantize(op, out)
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
out = self._quantize(weight_bits, module, weight)
out = self._dequantize(module, out)
return out
def quantize_output(self, output, config, op, **kwargs):
def quantize_output(self, output, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
output_bits = get_bits_length(config, 'output')
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"
......@@ -203,18 +207,18 @@ class QAT_Quantizer(Quantizer):
return output
current_min, current_max = torch.min(output), torch.max(output)
op.tracked_min_biased, op.tracked_min = update_ema(op.tracked_min_biased, current_min, op.ema_decay, self.steps)
op.tracked_max_biased, op.tracked_max = update_ema(op.tracked_max_biased, current_max, op.ema_decay, self.steps)
op.scale, op.zero_point = update_quantization_param(output_bits, op.tracked_min, op.tracked_max)
out = self._quantize(output_bits, op, output)
out = self._dequantize(op, out)
module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps)
module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps)
module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
out = self._quantize(output_bits, module, output)
out = self._dequantize(module, out)
return out
def fold_bn(self, config, **kwargs):
# TODO simulate folded weight
pass
def step(self):
def step_with_optimizer(self):
"""
override `Compressor`'s step hook; quantization only happens after a certain number of steps
"""
......@@ -226,11 +230,11 @@ class DoReFaQuantizer(Quantizer):
Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
(https://arxiv.org/abs/1606.06160)
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
def quantize_weight(self, weight, config, **kwargs):
weight_bits = get_bits_length(config, 'weight')
def quantize_weight(self, weight, wrapper, **kwargs):
weight_bits = get_bits_length(wrapper.config, 'weight')
out = weight.tanh()
out = out / (2 * out.abs().max()) + 0.5
out = self.quantize(out, weight_bits)
......@@ -256,17 +260,17 @@ class BNNQuantizer(Quantizer):
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
(https://arxiv.org/abs/1602.02830)
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = ClipGrad
def quantize_weight(self, weight, config, **kwargs):
def quantize_weight(self, weight, wrapper, **kwargs):
out = torch.sign(weight)
# remove zeros
out[out == 0] = 1
return out
def quantize_output(self, output, config, **kwargs):
def quantize_output(self, output, wrapper, **kwargs):
out = torch.sign(output)
# remove zeros
out[out == 0] = 1
......
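A hedged usage sketch of QAT_Quantizer with the new optional optimizer: when the optimizer is passed, each wrapper's trainable old_weight is added to a parameter group and step_with_optimizer is hooked into optimizer.step(). The import path, model, and bit widths are assumptions for illustration.

import torch
from nni.compression.torch import QAT_Quantizer

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
config_list = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},
    'op_types': ['Conv2d', 'Linear'],
}]
quantizer = QAT_Quantizer(model, config_list, optimizer)
model = quantizer.compress()
# training then proceeds as usual; quantize_weight / quantize_output are
# invoked from the wrappers' forward passes via QuantGrad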
......@@ -15,7 +15,7 @@ class WeightRankFilterPruner(Pruner):
importance criterion in convolution layers to achieve a preset level of network sparsity.
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
......@@ -24,15 +24,17 @@ class WeightRankFilterPruner(Pruner):
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list)
self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
def get_mask(self, base_mask, weight, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, layer, config, **kwargs):
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion of the kernel weights are masked.
......@@ -48,20 +50,21 @@ class WeightRankFilterPruner(Pruner):
dictionary for storing masks
"""
weight = layer.module.weight.data
op_type = layer.type
weight = wrapper.module.weight.data
op_type = wrapper.type
config = wrapper.config
assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
assert op_type in config.get('op_types')
if_calculated = kwargs["if_calculated"]
if if_calculated:
if wrapper.if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
mask_bias = torch.ones(wrapper.module.bias.size()).type_as(wrapper.module.bias).detach()
else:
mask_bias = None
mask = {'weight': mask_weight, 'bias': mask_bias}
mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
......@@ -69,7 +72,7 @@ class WeightRankFilterPruner(Pruner):
return mask
mask = self.get_mask(mask, weight, num_prune)
finally:
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
wrapper.if_calculated = True
return mask
......@@ -82,7 +85,7 @@ class L1FilterPruner(WeightRankFilterPruner):
https://arxiv.org/abs/1608.08710
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
......@@ -91,9 +94,11 @@ class L1FilterPruner(WeightRankFilterPruner):
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
def get_mask(self, base_mask, weight, num_prune):
"""
......@@ -121,7 +126,7 @@ class L1FilterPruner(WeightRankFilterPruner):
mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_abs_structured, threshold).type_as(weight)
return {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
class L2FilterPruner(WeightRankFilterPruner):
......@@ -130,7 +135,7 @@ class L2FilterPruner(WeightRankFilterPruner):
smallest L2 norm of the weights.
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
......@@ -139,9 +144,11 @@ class L2FilterPruner(WeightRankFilterPruner):
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
def get_mask(self, base_mask, weight, num_prune):
"""
......@@ -167,7 +174,7 @@ class L2FilterPruner(WeightRankFilterPruner):
mask_weight = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
mask_bias = torch.gt(w_l2_norm, threshold).type_as(weight)
return {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
class FPGMPruner(WeightRankFilterPruner):
......@@ -177,7 +184,7 @@ class FPGMPruner(WeightRankFilterPruner):
https://arxiv.org/pdf/1811.00250.pdf
"""
def __init__(self, model, config_list):
def __init__(self, model, config_list, optimizer):
"""
Parameters
----------
......@@ -186,8 +193,11 @@ class FPGMPruner(WeightRankFilterPruner):
config_list: list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
super().__init__(model, config_list)
super().__init__(model, config_list, optimizer)
assert isinstance(optimizer, torch.optim.Optimizer), "FPGM pruner is an iterative pruner, please pass optimizer of the model to it"
def get_mask(self, base_mask, weight, num_prune):
"""
......@@ -208,9 +218,9 @@ class FPGMPruner(WeightRankFilterPruner):
"""
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
base_mask['weight_mask'][idx] = 0.
if base_mask['bias_mask'] is not None:
base_mask['bias_mask'][idx] = 0.
return base_mask
def _get_min_gm_kernel_idx(self, weight, n):
......@@ -258,4 +268,4 @@ class FPGMPruner(WeightRankFilterPruner):
def update_epoch(self, epoch):
for wrapper in self.get_modules_wrapper():
wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
wrapper.if_calculated = False
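For reference, the filter-level masks built by these pruners simply zero whole output channels under the renamed 'weight_mask' / 'bias_mask' keys. A toy, self-contained sketch of the L1-norm variant (values are illustrative, not taken from this diff):

import torch

weight = torch.randn(4, 3, 3, 3)                  # Conv2d weight with 4 filters
mask = {'weight_mask': torch.ones_like(weight), 'bias_mask': torch.ones(4)}
num_prune = 2
l1_norm = weight.abs().view(4, -1).sum(dim=1)
for idx in torch.argsort(l1_norm)[:num_prune]:    # filters with smallest L1 norm
    mask['weight_mask'][idx] = 0.
    mask['bias_mask'][idx] = 0.
# after weight.data.mul_(mask['weight_mask']) the pruned filters output zeros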
......@@ -4,6 +4,7 @@
import logging
import datetime
from nni.assessor import Assessor, AssessResult
from nni.utils import extract_scalar_history
from .model_factory import CurveModel
logger = logging.getLogger('curvefitting_Assessor')
......@@ -91,10 +92,11 @@ class CurvefittingAssessor(Assessor):
Exception
unrecognized exception in curvefitting_assessor
"""
self.trial_history = trial_history
scalar_trial_history = extract_scalar_history(trial_history)
self.trial_history = scalar_trial_history
if not self.set_best_performance:
return AssessResult.Good
curr_step = len(trial_history)
curr_step = len(scalar_trial_history)
if curr_step < self.start_step:
return AssessResult.Good
......@@ -106,7 +108,7 @@ class CurvefittingAssessor(Assessor):
start_time = datetime.datetime.now()
# Predict the final result
curvemodel = CurveModel(self.target_pos)
predict_y = curvemodel.predict(trial_history)
predict_y = curvemodel.predict(scalar_trial_history)
logger.info('Prediction done. Trial job id = %s. Predict value = %s', trial_job_id, predict_y)
if predict_y is None:
logger.info('wait for more information to predict precisely')
......
......@@ -3,6 +3,7 @@
import logging
from nni.assessor import Assessor, AssessResult
from nni.utils import extract_scalar_history
logger = logging.getLogger('medianstop_Assessor')
......@@ -91,20 +92,12 @@ class MedianstopAssessor(Assessor):
if curr_step < self._start_step:
return AssessResult.Good
try:
num_trial_history = [float(ele) for ele in trial_history]
except (TypeError, ValueError) as error:
logger.warning('incorrect data type or value:')
logger.exception(error)
except Exception as error:
logger.warning('unrecognized exception in medianstop_assessor:')
logger.exception(error)
self._update_data(trial_job_id, num_trial_history)
scalar_trial_history = extract_scalar_history(trial_history)
self._update_data(trial_job_id, scalar_trial_history)
if self._high_better:
best_history = max(trial_history)
best_history = max(scalar_trial_history)
else:
best_history = min(trial_history)
best_history = min(scalar_trial_history)
avg_array = []
for id_ in self._completed_avg_history:
......
......@@ -234,4 +234,5 @@ class MsgDispatcher(MsgDispatcherBase):
if multi_thread_enabled():
self._handle_final_metric_data(data)
else:
data['value'] = to_json(data['value'])
self.enqueue_command(CommandType.ReportMetricData, data)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
# This file is copied from PyTorch 1.4, with bug fixes.
# Likely to be removed in future.
import torch
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef
from torch.utils.tensorboard._pytorch_graph import GraphPy, CLASSTYPE_KIND, GETATTR_KIND, NodePyIO, NodePyOP
def parse(graph, trace, args=None, omit_useless_nodes=True):
"""This method parses an optimized PyTorch model graph and produces
a list of nodes and node stats for eventual conversion to TensorBoard
protobuf format.
Args:
graph (PyTorch module): The model graph to be parsed.
trace (PyTorch JIT TracedModule): The model trace to be parsed.
args (tuple): input tensor[s] for the model.
omit_useless_nodes (boolean): Whether to remove nodes from the graph.
"""
n_inputs = len(args)
scope = {}
nodes_py = GraphPy()
for node in graph.inputs():
if omit_useless_nodes:
if len(node.uses()) == 0:  # number of users of the node (= number of outputs / fan-out)
continue
if node.type().kind() != CLASSTYPE_KIND:
nodes_py.append(NodePyIO(node, 'input'))
attr_to_scope = dict()
node_to_name = lambda d: str(d).split(":")[0].strip()
for node in graph.nodes():
if node.kind() == GETATTR_KIND:
attr_name = node.s('name')
node_name = node_to_name(node)
parent = node.input().node()
if parent.kind() == GETATTR_KIND: # If the parent node is not the top-level "self" node
parent_attr_name = parent.s('name')
parent_scope = attr_to_scope[node_to_name(parent)]
attr_scope = parent_scope.split('/')[-1]
attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
else:
attr_to_scope[node_name] = '__module.{}'.format(attr_name)
# We don't need classtype nodes; scope will provide this information
if node.output().type().kind() != CLASSTYPE_KIND:
node_py = NodePyOP(node)
node_py.scopeName = attr_to_scope[node_name]
nodes_py.append(node_py)
else:
nodes_py.append(NodePyOP(node))
for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops
node_py = NodePyIO(node, 'output')
node_py.debugName = "output.{}".format(i + 1)
node_py.inputs = [node.debugName()]
nodes_py.append(node_py)
def parse_traced_name(module_name):
prefix = 'TracedModule['
suffix = ']'
if module_name.startswith(prefix) and module_name.endswith(suffix):
module_name = module_name[len(prefix):-len(suffix)]
return module_name
alias_to_name = dict()
base_name = parse_traced_name(trace._name)
for name, module in trace.named_modules(prefix='__module'):
mod_name = parse_traced_name(module._name)
attr_name = name.split('.')[-1]
alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)
for node in nodes_py.nodes_op:
module_aliases = node.scopeName.split('/')[-1].split('.')
module_name = ''
for i, alias in enumerate(module_aliases):
if i == 0:
module_name = alias
node.scopeName = base_name
else:
module_name += '.' + alias
node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)
nodes_py.populate_namespace_from_OP_to_IO()
return nodes_py.to_proto()
def graph(model, args, verbose=False):
"""
This method processes a PyTorch model and produces a `GraphDef` proto
that can be logged to TensorBoard.
Args:
model (PyTorch module): The model to be parsed.
args (tuple): input tensor[s] for the model.
verbose (bool): Whether to print out verbose information while
processing.
"""
with torch.onnx.set_training(model, False): # TODO: move outside of torch.onnx?
try:
trace = torch.jit.trace(model, args)
graph = trace.graph
torch._C._jit_pass_inline(graph)
except RuntimeError as e:
print(e)
print('Error occurs, No graph saved')
raise e
if verbose:
print(graph)
list_of_nodes = parse(graph, trace, args)
# We are hardcoding that this was run on CPU even though it might have actually
# run on GPU. Note this is what is shown in TensorBoard and has no bearing
# on actual execution.
# TODO: See if we can extract GPU vs CPU information from the PyTorch model
# and pass it correctly to TensorBoard.
#
# Definition of StepStats and DeviceStepStats can be found at
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
# and
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
# The producer version has been reverse engineered from standard
# TensorBoard logged data.
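A hedged sketch of calling the copied graph() helper directly; the toy model and input shape are assumptions. It returns a GraphDef proto plus RunMetadata, which downstream code (such as the Mutator.graph method later in this diff) serializes to JSON.

import torch

toy_model = torch.nn.Sequential(torch.nn.Conv2d(3, 4, kernel_size=3), torch.nn.ReLU())
graph_def, step_stats = graph(toy_model, (torch.randn(1, 3, 8, 8),), verbose=False)
print(len(graph_def.node))    # number of nodes recovered from the JIT trace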
......@@ -23,7 +23,9 @@ class StackedLSTMCell(nn.Module):
curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
next_c.append(curr_c)
next_h.append(curr_h)
inputs = curr_h[-1]
# the current implementation only supports a batch size of 1,
# but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1)
return next_c, next_h
......
......@@ -2,7 +2,9 @@
# Licensed under the MIT license.
import logging
from collections import defaultdict
import numpy as np
import torch
from nni.nas.pytorch.base_mutator import BaseMutator
......@@ -15,6 +17,7 @@ class Mutator(BaseMutator):
def __init__(self, model):
super().__init__(model)
self._cache = dict()
self._connect_all = False
def sample_search(self):
"""
......@@ -57,6 +60,74 @@ class Mutator(BaseMutator):
"""
return self.sample_final()
def status(self):
"""
Return current selection status of mutator.
Returns
-------
dict
A mapping from mutable keys to decisions. All weights (boolean and float)
are converted into real numbers. Numpy arrays and tensors are converted into lists.
"""
data = dict()
for k, v in self._cache.items():
if torch.is_tensor(v):
v = v.detach().cpu().numpy()
if isinstance(v, np.ndarray):
v = v.astype(np.float32).tolist()
data[k] = v
return data
def graph(self, inputs):
"""
Return model supernet graph.
Parameters
----------
inputs: tuple of tensor
Inputs that will be fed into the network.
Returns
-------
dict
Containing ``node``, in Tensorboard GraphDef format.
Additional key ``mutable`` is a map from key to list of modules.
"""
if not torch.__version__.startswith("1.4"):
logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
from ._graph_utils import graph
from google.protobuf import json_format
# protobuf should be installed as long as tensorboard is installed
try:
self._connect_all = True
graph_def, _ = graph(self.model, inputs, verbose=False)
result = json_format.MessageToDict(graph_def)
finally:
self._connect_all = False
# `mutable` is to map the keys to a list of corresponding modules.
# A key can be linked to multiple modules, use `dedup=False` to find them all.
result["mutable"] = defaultdict(list)
for mutable in self.mutables.traverse(deduplicate=False):
# A module will be represented in the format of
# [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d", "name": "conv"}]
# which will be concatenated into Net/Cell[cell1]/Conv2d[conv] in frontend.
# This format is aligned with the scope name jit gives.
modules = mutable.name.split(".")
path = [
{"type": self.model.__class__.__name__, "name": ""}
]
m = self.model
for module in modules:
m = getattr(m, module)
path.append({
"type": m.__class__.__name__,
"name": module
})
result["mutable"][mutable.key].append(path)
return result
def on_forward_layer_choice(self, mutable, *inputs):
"""
On default, this method retrieves the decision obtained previously, and select certain operations.
......@@ -75,6 +146,11 @@ class Mutator(BaseMutator):
tuple of torch.Tensor and torch.Tensor
Output and mask.
"""
if self._connect_all:
return self._all_connect_tensor_reduction(mutable.reduction,
[op(*inputs) for op in mutable.choices]), \
torch.ones(mutable.length)
def _map_fn(op, *inputs):
return op(*inputs)
......@@ -101,6 +177,9 @@ class Mutator(BaseMutator):
tuple of torch.Tensor and torch.Tensor
Output and mask.
"""
if self._connect_all:
return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \
torch.ones(mutable.n_candidates)
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
......@@ -131,6 +210,13 @@ class Mutator(BaseMutator):
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
def _all_connect_tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
return tensor_list
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
return torch.stack(tensor_list).sum(0)
def _get_decision(self, mutable):
"""
By default, this method checks whether `mutable.key` is already in the decision cache,
......
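A hedged usage sketch of the new Mutator introspection helpers; MyMutator stands in for any concrete Mutator subclass and model for a searchable supernet, both assumptions for illustration.

import torch

mutator = MyMutator(model)
mutator.reset()                        # sample decisions into the internal cache
print(mutator.status())                # e.g. {'layer1_choice': [1.0, 0.0, 0.0]}

dummy_input = torch.randn(1, 3, 32, 32)
# graph() temporarily connects all candidate ops so the JIT trace covers them,
# then returns a GraphDef-style dict plus a "mutable" key-to-modules map
result = mutator.graph((dummy_input,))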
......@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
from .standalone import *
elif trial_env_vars.NNI_PLATFORM == 'unittest':
from .test import *
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn'):
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'):
from .local import *
else:
raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
......@@ -62,6 +62,13 @@ def extract_scalar_reward(value, scalar_key='default'):
"""
Extract scalar reward from trial result.
Parameters
----------
value : int, float, dict
the reported final metric data
scalar_key : str
the key name that indicates the numeric number
Raises
------
RuntimeError
......@@ -78,6 +85,26 @@ def extract_scalar_reward(value, scalar_key='default'):
return reward
def extract_scalar_history(trial_history, scalar_key='default'):
"""
Extract scalar value from a list of intermediate results.
Parameters
----------
trial_history : list
accumulated intermediate results of a trial
scalar_key : str
the key name that indicates the numeric number
Raises
------
RuntimeError
Incorrect intermediate result: each intermediate result should be float/int,
or a dict which has a key named "default" whose value is float/int.
"""
return [extract_scalar_reward(ele, scalar_key) for ele in trial_history]
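An example of the new helper in use; the dict entries follow the same 'default' key convention documented for extract_scalar_reward above:

from nni.utils import extract_scalar_history

history = [0.10, {'default': 0.20, 'other_metric': 1.3}, 0.35]
print(extract_scalar_history(history))   # -> [0.1, 0.2, 0.35]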
def convert_dict2tuple(value):
"""
convert dict type to tuple to solve unhashable problem.
......@@ -90,7 +117,9 @@ def convert_dict2tuple(value):
def init_dispatcher_logger():
""" Initialize dispatcher logging configuration"""
"""
Initialize dispatcher logging configuration
"""
logger_file_path = 'dispatcher.log'
if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
......
......@@ -59,29 +59,34 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
});
// true: parameters are wrong
let flag = false;
let parametersIllegal = false;
Object.keys(customized).map(item => {
if (item !== 'tag') {
// unified data type
if (typeof copyTrialParameter[item] === 'number' && typeof customized[item] === 'string') {
customized[item] = JSON.parse(customized[item]);
}
if (searchSpace[item] === undefined) {
// sometimes the schema of trial parameters is different from search space
// e.g. Batch Tuner
return;
}
if (searchSpace[item]._type === 'choice') {
if (searchSpace[item]._value.find((val: string | number) =>
val === customized[item]) === undefined) {
flag = true;
parametersIllegal = true;
return;
}
} else {
if (customized[item] < searchSpace[item]._value[0]
|| customized[item] > searchSpace[item]._value[1]) {
flag = true;
parametersIllegal = true;
return;
}
}
}
});
if (flag !== false) {
if (parametersIllegal !== false) {
// open the warning modal
this.setState(() => ({ isShowWarning: true, customParameters: customized }));
} else {
......@@ -269,4 +274,4 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
}
}
export default Customize;
\ No newline at end of file
export default Customize;