"...Accuracy_Validation/ResNet50_Official/README.md" did not exist on "0e04b692e6f879d1641a890cb3b32913d9e341c8"
Unverified Commit 1a5c0172 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #224 from microsoft/master

merge master
parents b9a7a95d ae81ec47
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import torch
from .compressor import Pruner
_logger = logging.getLogger(__name__)
__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'LotteryTicketPruner']
logger = logging.getLogger('torch pruner')
class LevelPruner(Pruner):
"""
Prune to an exact pruning level specification
"""
def __init__(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
def calc_mask(self, layer, config):
"""
Calculate the mask of given layer
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_name = layer.name
if op_name not in self.mask_calculated_ops:
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight': mask_weight}
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
else:
assert op_name in self.mask_dict, "op_name not in the mask_dict"
mask = self.mask_dict[op_name]
return mask
class AGP_Pruner(Pruner):
"""
An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity.
Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf
"""
def __init__(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
super().__init__(model, config_list)
self.now_epoch = 0
self.if_init_list = {}
def calc_mask(self, layer, config):
"""
Calculate the mask of given layer
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_name = layer.name
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
and (self.now_epoch - start_epoch) % freq == 0:
mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate new mask, we should update weigth first
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
else:
new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
return new_mask
def compute_target_sparsity(self, config):
"""
Calculate the sparsity for pruning
Parameters
----------
config : dict
Layer's pruning config
Returns
-------
float
Target sparsity to be pruned
"""
end_epoch = config.get('end_epoch', 1)
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
final_sparsity = config.get('final_sparsity', 0)
initial_sparsity = config.get('initial_sparsity', 0)
if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
return final_sparsity
if end_epoch <= self.now_epoch:
return final_sparsity
span = ((end_epoch - start_epoch - 1) // freq) * freq
assert span > 0
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
return target_sparsity
def update_epoch(self, epoch):
"""
Update epoch
Parameters
----------
epoch : int
current training epoch
"""
if epoch > 0:
self.now_epoch = epoch
for k in self.if_init_list.keys():
self.if_init_list[k] = True
class SlimPruner(Pruner):
"""
A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
"""
def __init__(self, model, config_list):
"""
Parameters
----------
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
def calc_mask(self, layer, config):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1:
return mask
w_abs = weight.abs()
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
return mask
class LotteryTicketPruner(Pruner):
"""
......
......@@ -5,7 +5,7 @@ import logging
import torch
from .compressor import Quantizer, QuantGrad, QuantType
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']
logger = logging.getLogger(__name__)
......
......@@ -5,240 +5,9 @@ import logging
import torch
from .compressor import Pruner
__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner',
'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
logger = logging.getLogger('torch pruner')
class LevelPruner(Pruner):
"""
Prune to an exact pruning level specification
"""
def __init__(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
def calc_mask(self, layer, config):
"""
Calculate the mask of given layer
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_name = layer.name
if op_name not in self.mask_calculated_ops:
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
mask_weight = torch.gt(w_abs, threshold).type_as(weight)
mask = {'weight': mask_weight}
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
else:
assert op_name in self.mask_dict, "op_name not in the mask_dict"
mask = self.mask_dict[op_name]
return mask
class AGP_Pruner(Pruner):
"""
An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity.
Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf
"""
def __init__(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
super().__init__(model, config_list)
self.now_epoch = 0
self.if_init_list = {}
def calc_mask(self, layer, config):
"""
Calculate the mask of given layer
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_name = layer.name
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
and (self.now_epoch - start_epoch) % freq == 0:
mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate new mask, we should update weigth first
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
else:
new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
return new_mask
def compute_target_sparsity(self, config):
"""
Calculate the sparsity for pruning
Parameters
----------
config : dict
Layer's pruning config
Returns
-------
float
Target sparsity to be pruned
"""
end_epoch = config.get('end_epoch', 1)
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
final_sparsity = config.get('final_sparsity', 0)
initial_sparsity = config.get('initial_sparsity', 0)
if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
return final_sparsity
if end_epoch <= self.now_epoch:
return final_sparsity
span = ((end_epoch - start_epoch - 1) // freq) * freq
assert span > 0
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
return target_sparsity
def update_epoch(self, epoch):
"""
Update epoch
Parameters
----------
epoch : int
current training epoch
"""
if epoch > 0:
self.now_epoch = epoch
for k in self.if_init_list.keys():
self.if_init_list[k] = True
class SlimPruner(Pruner):
"""
A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
"""
def __init__(self, model, config_list):
"""
Parameters
----------
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
def calc_mask(self, layer, config):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1:
return mask
w_abs = weight.abs()
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
return mask
__all__ = ['L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
logger = logging.getLogger('torch weight rank filter pruners')
class WeightRankFilterPruner(Pruner):
"""
......@@ -260,8 +29,8 @@ class WeightRankFilterPruner(Pruner):
super().__init__(model, config_list)
self.mask_calculated_ops = set() # operations whose mask has been calculated
def _get_mask(self, base_mask, weight, num_prune):
return {'weight': None, 'bias': None}
def get_mask(self, base_mask, weight, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
def calc_mask(self, layer, config):
"""
......@@ -299,7 +68,7 @@ class WeightRankFilterPruner(Pruner):
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1:
return mask
mask = self._get_mask(mask, weight, num_prune)
mask = self.get_mask(mask, weight, num_prune)
finally:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
......@@ -328,7 +97,7 @@ class L1FilterPruner(WeightRankFilterPruner):
super().__init__(model, config_list)
def _get_mask(self, base_mask, weight, num_prune):
def get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest sum of its absolute kernel weights are masked.
......@@ -376,7 +145,7 @@ class L2FilterPruner(WeightRankFilterPruner):
super().__init__(model, config_list)
def _get_mask(self, base_mask, weight, num_prune):
def get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest L2 norm of the absolute kernel weights are masked.
......@@ -422,7 +191,7 @@ class FPGMPruner(WeightRankFilterPruner):
"""
super().__init__(model, config_list)
def _get_mask(self, base_mask, weight, num_prune):
def get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest sum of its absolute kernel weights are masked.
......@@ -491,251 +260,3 @@ class FPGMPruner(WeightRankFilterPruner):
def update_epoch(self, epoch):
self.mask_calculated_ops = set()
class ActivationRankFilterPruner(Pruner):
"""
A structured pruning base class that prunes the filters with the smallest
importance criterion in convolution layers to achieve a preset level of network sparsity.
Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
"Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
https://arxiv.org/abs/1607.03250
Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
"Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
https://arxiv.org/abs/1611.06440
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
self.statistics_batch_num = statistics_batch_num
self.collected_activation = {}
self.hooks = {}
assert activation in ['relu', 'relu6']
if activation == 'relu':
self.activation = torch.nn.functional.relu
elif activation == 'relu6':
self.activation = torch.nn.functional.relu6
else:
self.activation = None
def compress(self):
"""
Compress the model, register a hook for collecting activations.
"""
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
self._instrument_layer(layer, config)
self.collected_activation[layer.name] = []
def _hook(module_, input_, output, name=layer.name):
if len(self.collected_activation[name]) < self.statistics_batch_num:
self.collected_activation[name].append(self.activation(output.detach().cpu()))
layer.module.register_forward_hook(_hook)
return self.bound_model
def _get_mask(self, base_mask, activations, num_prune):
return {'weight': None, 'bias': None}
def calc_mask(self, layer, config):
"""
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
dict
dictionary for storing masks
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
assert op_type in ['Conv2d'], "only support Conv2d"
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, 'bias') and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
else:
mask_bias = None
mask = {'weight': mask_weight, 'bias': mask_bias}
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1 or len(self.collected_activation[layer.name]) < self.statistics_batch_num:
return mask
mask = self._get_mask(mask, self.collected_activation[layer.name], num_prune)
finally:
if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
return mask
class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
smallest APoZ(average percentage of zeros) of output activations.
Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
"Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
https://arxiv.org/abs/1607.03250
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
def _get_mask(self, base_mask, activations, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest APoZ(average percentage of zeros) of output activations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
activations : list
Layer's output activations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
apoz = self._calc_apoz(activations)
prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
return base_mask
def _calc_apoz(self, activations):
"""
Calculate APoZ(average percentage of zeros) of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's APoZ(average percentage of zeros) of the activations
"""
activations = torch.cat(activations, 0)
_eq_zero = torch.eq(activations, torch.zeros_like(activations))
_apoz = torch.sum(_eq_zero, dim=(0, 2, 3)) / torch.numel(_eq_zero[:, 0, :, :])
return _apoz
class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
"""
A structured pruning algorithm that prunes the filters with the
smallest mean value of output activations.
Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
"Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
https://arxiv.org/abs/1611.06440
"""
def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
activation : str
Activation function
statistics_batch_num : int
Num of batches for activation statistics
"""
super().__init__(model, config_list, activation, statistics_batch_num)
def _get_mask(self, base_mask, activations, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest APoZ(average percentage of zeros) of output activations are masked.
Parameters
----------
base_mask : dict
The basic mask with the same shape of weight, all item in the basic mask is 1.
activations : list
Layer's output activations
num_prune : int
Num of filters to prune
Returns
-------
dict
dictionary for storing masks
"""
mean_activation = self._cal_mean_activation(activations)
prune_indices = torch.argsort(mean_activation)[:num_prune]
for idx in prune_indices:
base_mask['weight'][idx] = 0.
if base_mask['bias'] is not None:
base_mask['bias'][idx] = 0.
return base_mask
def _cal_mean_activation(self, activations):
"""
Calculate mean value of activations.
Parameters
----------
activations : list
Layer's output activations
Returns
-------
torch.Tensor
Filter's mean value of the output activations
"""
activations = torch.cat(activations, 0)
mean_activation = torch.mean(activations, dim=(0, 2, 3))
return mean_activation
......@@ -10,7 +10,7 @@ import torch
import nni
from nni.env_vars import trial_env_vars
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
from nni.nas.pytorch.mutator import Mutator
logger = logging.getLogger(__name__)
......@@ -104,10 +104,11 @@ class ClassicMutator(Mutator):
search_space_item : list
The list for corresponding search space.
"""
candidate_repr = search_space_item["candidates"]
multihot_list = [False] * mutable.n_candidates
for i, v in zip(idx, value):
assert 0 <= i < mutable.n_candidates and search_space_item[i] == v, \
"Index '{}' in search space '{}' is not '{}'".format(i, search_space_item, v)
assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
"Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
multihot_list[i] = True
return torch.tensor(multihot_list, dtype=torch.bool) # pylint: disable=not-callable
......@@ -121,17 +122,20 @@ class ClassicMutator(Mutator):
self._chosen_arch.keys())
result = dict()
for mutable in self.mutables:
assert mutable.key in self._chosen_arch, "Expected '{}' in chosen arch, but not found.".format(mutable.key)
data = self._chosen_arch[mutable.key]
assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
"'{}' is not a valid choice.".format(data)
value = data["_value"]
idx = data["_idx"]
search_space_item = self._search_space[mutable.key]["_value"]
if isinstance(mutable, (LayerChoice, InputChoice)):
assert mutable.key in self._chosen_arch, \
"Expected '{}' in chosen arch, but not found.".format(mutable.key)
data = self._chosen_arch[mutable.key]
assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
"'{}' is not a valid choice.".format(data)
if isinstance(mutable, LayerChoice):
result[mutable.key] = self._sample_layer_choice(mutable, idx, value, search_space_item)
result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
self._search_space[mutable.key]["_value"])
elif isinstance(mutable, InputChoice):
result[mutable.key] = self._sample_input_choice(mutable, idx, value, search_space_item)
result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
self._search_space[mutable.key]["_value"])
elif isinstance(mutable, MutableScope):
logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
else:
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return result
......@@ -190,6 +194,8 @@ class ClassicMutator(Mutator):
search_space[key] = {"_type": INPUT_CHOICE,
"_value": {"candidates": mutable.choose_from,
"n_chosen": mutable.n_chosen}}
elif isinstance(mutable, MutableScope):
logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
else:
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return search_space
......
......@@ -14,6 +14,26 @@ _logger = logging.getLogger(__name__)
class DartsMutator(Mutator):
"""
Connects the model in a DARTS (differentiable) way.
An extra connection is automatically inserted for each LayerChoice, when this connection is selected, there is no
op on this LayerChoice (namely a ``ZeroOp``), in which case, every element in the exported choice list is ``false``
(not chosen).
All input choice will be fully connected in the search phase. On exporting, the input choice will choose inputs based
on keys in ``choose_from``. If the keys were to be keys of LayerChoices, the top logit of the corresponding LayerChoice
will join the competition of input choice to compete against other logits. Otherwise, the logit will be assumed 0.
It's possible to cut branches by setting parameter ``choices`` in a particular position to ``-inf``. After softmax, the
value would be 0. Framework will ignore 0 values and not connect. Note that the gradient on the ``-inf`` location will
be 0. Since manipulations with ``-inf`` will be ``nan``, you need to handle the gradient update phase carefully.
Attributes
----------
choices: ParameterDict
dict that maps keys of LayerChoices to weighted-connection float tensors.
"""
def __init__(self, model):
super().__init__(model)
self.choices = nn.ParameterDict()
......
......@@ -19,6 +19,42 @@ class DartsTrainer(Trainer):
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
"""
Initialize a DartsTrainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset_train : Dataset
Dataset for training. Will be split for training weights and architecture weights.
dataset_valid : Dataset
Dataset for testing.
mutator : DartsMutator
Use in case of customizing your own DartsMutator. By default will instantiate a DartsMutator.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
callbacks : list of Callback
list of callbacks to trigger at events.
arc_learning_rate : float
Learning rate of architecture parameters.
unrolled : float
``True`` if using second order optimization, else first order optimization.
"""
super().__init__(model, mutator if mutator is not None else DartsMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
......
......@@ -30,11 +30,41 @@ class StackedLSTMCell(nn.Module):
class EnasMutator(Mutator):
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, branch_bias=0.25):
skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
"""
Initialize a EnasMutator.
Parameters
----------
model : nn.Module
PyTorch model.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
cell_exit_extra_step : bool
If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
skip_target : float
Target probability that skipconnect will appear.
temperature : float
Temperature constant that divides the logits.
branch_bias : float
Manual bias applied to make some operations more likely to be chosen.
Currently this is implemented with a hardcoded match rule that aligns with original repo.
If a mutable has a ``reduce`` in its key, all its op choices
that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others
receive a bias of ``-self.branch_bias``.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
super().__init__(model)
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.cell_exit_extra_step = cell_exit_extra_step
self.skip_target = skip_target
self.branch_bias = branch_bias
......@@ -45,6 +75,8 @@ class EnasMutator(Mutator):
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable
assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
self.bias_dict = nn.ParameterDict()
......@@ -110,15 +142,17 @@ class EnasMutator(Mutator):
def _sample_layer_choice(self, mutable):
self._lstm_next_step()
logit = self.soft(self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += torch.sum(log_prob)
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy)
self.sample_entropy += self.entropy_reduction(entropy)
self._inputs = self.embedding(branch_id)
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
......@@ -133,6 +167,8 @@ class EnasMutator(Mutator):
query = torch.cat(query, 0)
query = torch.tanh(query + self.attn_query(self._h[-1]))
query = self.v_attn(query)
if self.temperature is not None:
query /= self.temperature
if self.tanh_constant is not None:
query = self.tanh_constant * torch.tanh(query)
......@@ -153,7 +189,7 @@ class EnasMutator(Mutator):
log_prob = self.cross_entropy_loss(logit, index)
self._inputs = anchors[index.item()]
self.sample_log_prob += torch.sum(log_prob)
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy)
self.sample_entropy += self.entropy_reduction(entropy)
return skip.bool()
......@@ -2,11 +2,14 @@
# Licensed under the MIT license.
import logging
from itertools import cycle
import torch
import torch.nn as nn
import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from nni.nas.pytorch.utils import AverageMeterGroup, to_device
from .mutator import EnasMutator
logger = logging.getLogger(__name__)
......@@ -16,13 +19,68 @@ class EnasTrainer(Trainer):
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4):
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
test_arc_per_epoch=1):
"""
Initialize an EnasTrainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
reward_function : callable
Receives logits and ground truth label, return a tensor, which will be feeded to RL controller as reward.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset_train : Dataset
Dataset for training. Will be split for training weights and architecture weights.
dataset_valid : Dataset
Dataset for testing.
mutator : EnasMutator
Use when customizing your own mutator or a mutator with customized parameters.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
callbacks : list of Callback
list of callbacks to trigger at events.
entropy_weight : float
Weight of sample entropy loss.
skip_weight : float
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
child_steps : int
How many mini-batches for model training per epoch.
mutator_lr : float
Learning rate for RL controller.
mutator_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
mutator_steps : int
Number of mini-batches for each epoch of RL controller learning.
aux_weight : float
Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
test_arc_per_epoch : int
How many architectures are chosen for direct test after each epoch.
"""
super().__init__(model, mutator if mutator is not None else EnasMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.reward_function = reward_function
self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
self.batch_size = batch_size
self.workers = workers
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
......@@ -30,32 +88,40 @@ class EnasTrainer(Trainer):
self.baseline = 0.
self.mutator_steps_aggregate = mutator_steps_aggregate
self.mutator_steps = mutator_steps
self.child_steps = child_steps
self.aux_weight = aux_weight
self.test_arc_per_epoch = test_arc_per_epoch
self.init_dataloader()
def init_dataloader(self):
n_train = len(self.dataset_train)
split = n_train // 10
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=workers)
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=workers)
num_workers=self.workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
batch_size=self.batch_size,
num_workers=self.workers)
self.train_loader = cycle(self.train_loader)
self.valid_loader = cycle(self.valid_loader)
def train_one_epoch(self, epoch):
# Sample model and train
self.model.train()
self.mutator.eval()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
for step in range(1, self.child_steps + 1):
x, y = next(self.train_loader)
x, y = to_device(x, self.device), to_device(y, self.device)
self.optimizer.zero_grad()
with torch.no_grad():
......@@ -71,55 +137,71 @@ class EnasTrainer(Trainer):
loss = self.loss(logits, y)
loss = loss + self.aux_weight * aux_loss
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
self.optimizer.step()
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
self.num_epochs, step, self.child_steps, meters)
# Train sampler (mutator)
self.model.eval()
self.mutator.train()
meters = AverageMeterGroup()
mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate
while mutator_step < total_mutator_steps:
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
for mutator_step in range(1, self.mutator_steps + 1):
self.mutator_optim.zero_grad()
for step in range(1, self.mutator_steps_aggregate + 1):
x, y = next(self.valid_loader)
x, y = to_device(x, self.device), to_device(y, self.device)
self.mutator.reset()
with torch.no_grad():
logits = self.model(x)
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight is not None:
reward += self.entropy_weight * self.mutator.sample_entropy
if self.entropy_weight:
reward += self.entropy_weight * self.mutator.sample_entropy.item()
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
self.baseline = self.baseline.detach().item()
loss = self.mutator.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.mutator.sample_skip_penalty
metrics["reward"] = reward
metrics["loss"] = loss.item()
metrics["ent"] = self.mutator.sample_entropy.item()
metrics["log_prob"] = self.mutator.sample_log_prob.item()
metrics["baseline"] = self.baseline
metrics["skip"] = self.mutator.sample_skip_penalty
loss = loss / self.mutator_steps_aggregate
loss /= self.mutator_steps_aggregate
loss.backward()
meters.update(metrics)
if mutator_step % self.mutator_steps_aggregate == 0:
self.mutator_optim.step()
self.mutator_optim.zero_grad()
cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
if self.log_frequency is not None and cur_step % self.log_frequency == 0:
logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
meters)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, self.num_epochs,
mutator_step // self.mutator_steps_aggregate + 1, self.mutator_steps, meters)
mutator_step += 1
if mutator_step >= total_mutator_steps:
break
nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
self.mutator_optim.step()
def validate_one_epoch(self, epoch):
pass
with torch.no_grad():
for arc_id in range(self.test_arc_per_epoch):
meters = AverageMeterGroup()
for x, y in self.test_loader:
x, y = to_device(x, self.device), to_device(y, self.device)
self.mutator.reset()
logits = self.model(x)
if isinstance(logits, tuple):
logits, _ = logits
metrics = self.metrics(logits, y)
loss = self.loss(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
meters.summary())
......@@ -41,18 +41,18 @@ class FixedArchitecture(Mutator):
return self._fixed_arc
def _encode_tensor(data, device):
def _encode_tensor(data):
if isinstance(data, list):
if all(map(lambda o: isinstance(o, bool), data)):
return torch.tensor(data, dtype=torch.bool, device=device) # pylint: disable=not-callable
return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable
else:
return torch.tensor(data, dtype=torch.float, device=device) # pylint: disable=not-callable
return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable
if isinstance(data, dict):
return {k: _encode_tensor(v, device) for k, v in data.items()}
return {k: _encode_tensor(v) for k, v in data.items()}
return data
def apply_fixed_architecture(model, fixed_arc_path, device=None):
def apply_fixed_architecture(model, fixed_arc_path):
"""
Load architecture from `fixed_arc_path` and apply to model.
......@@ -62,21 +62,16 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
Model with mutables.
fixed_arc_path : str
Path to the JSON that stores the architecture.
device : torch.device
Architecture weights will be transfered to `device`.
Returns
-------
FixedArchitecture
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if isinstance(fixed_arc_path, str):
with open(fixed_arc_path, "r") as f:
fixed_arc = json.load(f)
fixed_arc = _encode_tensor(fixed_arc, device)
fixed_arc = _encode_tensor(fixed_arc)
architecture = FixedArchitecture(model, fixed_arc)
architecture.to(device)
architecture.reset()
return architecture
......@@ -159,7 +159,7 @@ class InputChoice(Mutable):
"than number of candidates."
self.n_candidates = n_candidates
self.choose_from = choose_from
self.choose_from = choose_from.copy()
self.n_chosen = n_chosen
self.reduction = reduction
self.return_mask = return_mask
......
......@@ -211,6 +211,7 @@ class SPOSEvolution(Tuner):
Parameters
----------
result : dict
Chosen architectures to be exported.
"""
os.makedirs("checkpoints", exist_ok=True)
for i, cand in enumerate(result):
......
......@@ -17,6 +17,7 @@ class SPOSSupernetTrainingMutator(RandomMutator):
Parameters
----------
model : nn.Module
PyTorch model.
flops_func : callable
Callable that takes a candidate from `sample_search` and returns its candidate. When `flops_func`
is None, functions related to flops will be deactivated.
......
......@@ -21,6 +21,37 @@ class SPOSSupernetTrainer(Trainer):
optimizer, num_epochs, train_loader, valid_loader,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None):
"""
Parameters
----------
model : nn.Module
Model with mutables.
mutator : Mutator
A mutator object that has been initialized with the model.
loss : callable
Called with logits and targets. Returns a loss tensor.
metrics : callable
Returns a dict that maps metrics keys to metrics data.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
train_loader : iterable
Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
dataset_valid : iterable
Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
batch_size : int
Batch size.
workers: int
Number of threads for data preprocessing. Not used for this trainer. Maybe removed in future.
device : torch.device
Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
automatic detects GPU and selects GPU first.
log_frequency : int
Number of mini-batches to log metrics.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
assert torch.cuda.is_available()
super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
loss, metrics, optimizer, num_epochs, None, None,
......
......@@ -52,7 +52,7 @@ class Trainer(BaseTrainer):
workers : int
Number of workers used in data preprocessing.
device : torch.device
Device object. Either `torch.device("cuda")` or torch.device("cpu")`. When `None`, trainer will
Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
automatic detects GPU and selects GPU first.
log_frequency : int
Number of mini-batches to log metrics.
......@@ -96,12 +96,12 @@ class Trainer(BaseTrainer):
callback.on_epoch_begin(epoch)
# training
_logger.info("Epoch %d Training", epoch)
_logger.info("Epoch %d Training", epoch + 1)
self.train_one_epoch(epoch)
if validate:
# validation
_logger.info("Epoch %d Validating", epoch)
_logger.info("Epoch %d Validating", epoch + 1)
self.validate_one_epoch(epoch)
for callback in self.callbacks:
......
......@@ -4,6 +4,8 @@
import logging
from collections import OrderedDict
import torch
_counter = 0
_logger = logging.getLogger(__name__)
......@@ -15,7 +17,22 @@ def global_mutable_counting():
return _counter
def to_device(obj, device):
if torch.is_tensor(obj):
return obj.to(device)
if isinstance(obj, tuple):
return tuple(to_device(t, device) for t in obj)
if isinstance(obj, list):
return [to_device(t, device) for t in obj]
if isinstance(obj, dict):
return {k: to_device(v, device) for k, v in obj.items()}
if isinstance(obj, (int, float, str)):
return obj
raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
class AverageMeterGroup:
"""Average meter group for multiple average meters"""
def __init__(self):
self.meters = OrderedDict()
......@@ -33,7 +50,10 @@ class AverageMeterGroup:
return self.meters[item]
def __str__(self):
return " ".join(str(v) for _, v in self.meters.items())
return " ".join(str(v) for v in self.meters.values())
def summary(self):
return " ".join(v.summary() for v in self.meters.values())
class AverageMeter:
......@@ -72,6 +92,10 @@ class AverageMeter:
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
def summary(self):
fmtstr = '{name}: {avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
class StructuredMutableTreeNode:
"""
......
......@@ -91,7 +91,8 @@ class Compare extends React.Component<CompareProps, {}> {
},
yAxis: {
type: 'value',
name: 'Metric'
name: 'Metric',
scale: true
},
series: trialIntermediate
};
......
......@@ -28,12 +28,11 @@ class SuccessTable extends React.Component<SuccessTableProps, {}> {
{
title: 'Trial No.',
dataIndex: 'sequenceId',
width: 140,
className: 'tableHead'
}, {
title: 'ID',
dataIndex: 'id',
width: 60,
width: 80,
className: 'tableHead leftTitle',
render: (text: string, record: TableRecord): React.ReactNode => {
return (
......
......@@ -517,7 +517,7 @@ def manage_stopped_experiment(args, mode):
experiment_id = None
#find the latest stopped experiment
if not args.id:
print_error('Please set experiment id! \nYou could use \'nnictl {0} {id}\' to {0} a stopped experiment!\n' \
print_error('Please set experiment id! \nYou could use \'nnictl {0} id\' to {0} a stopped experiment!\n' \
'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
exit(1)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment