"include/ck/utility/static_buffer.hpp" did not exist on "78b987fbd6a7897ee9827187a231441794b13490"
Unverified Commit 3ec26b40 authored by liuzhe-lz, committed by GitHub

Merge master into dev-retiarii (#3178)

parent d165905d
......@@ -580,17 +580,55 @@ class QuantType:
"""
Enum class for quantization type.
"""
QUANT_INPUT = 0
QUANT_WEIGHT = 1
QUANT_OUTPUT = 2
QUANT_INPUT = 'input'
QUANT_WEIGHT = 'weight'
QUANT_OUTPUT = 'output'
class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
"""
@classmethod
def _quantize(cls, x, scale, zero_point):
"""
Reference function for quantizing x -- non-clamped.
Parameters
----------
x : Tensor
tensor to be quantized
scale : Tensor
scale for quantizing x
zero_point : Tensor
zero_point for quantizing x
Returns
-------
tensor
quantized x without clamping
"""
return ((x / scale) + zero_point).round()
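For intuition, a worked example of the affine formula above (illustrative values, not part of the diff):

# Sketch: q = round(x / scale + zero_point), with made-up values.
import torch
x = torch.tensor([-1.0, 0.0, 0.5, 1.0])
scale, zero_point = torch.tensor(0.5), torch.tensor(2.0)
print(((x / scale) + zero_point).round())  # tensor([0., 2., 3., 4.])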
@classmethod
def get_bits_length(cls, config, quant_type):
"""
Get the bit length from the quantization config
Parameters
----------
config : Dict
the configuration for quantization
quant_type : str
quant type
Returns
-------
int
n-bits for quantization configuration
"""
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)
@staticmethod
def quant_backward(tensor, grad_output, quant_type):
def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
"""
This method should be overridden by subclasses to provide a customized backward function;
the default implementation is the Straight-Through Estimator
......@@ -600,32 +638,45 @@ class QuantGrad(torch.autograd.Function):
input of quantization operation
grad_output : Tensor
gradient of the output of quantization operation
quant_type : QuantType
scale : Tensor
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
you can define different behavior for different types.
zero_point : Tensor
zero_point for quantizing tensor
qmin : Tensor
quant_min for quantizing tensor
qmax : Tensor
quant_max for quantizing tensor
Returns
-------
tensor
gradient of the input of quantization operation
"""
tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
mask = (tensor_q < qmin) | (tensor_q > qmax)
grad_output[mask] = 0
return grad_output
@staticmethod
def forward(ctx, tensor, quant_type, wrapper, **kwargs):
ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
if quant_type == QuantType.QUANT_INPUT:
return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
raise ValueError("unrecognized QuantType.")
bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
qmin, qmax = torch.tensor([0], device=tensor.device), torch.tensor([(1 << bits) - 1], device=tensor.device)
ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
return output
@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, quant_type)
tensor, scale, zero_point, qmin, qmax = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax)
return output, None, None, None
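A sketch of how a subclass might customize the backward pass (hypothetical class name; it mirrors the default clipping behavior above, but without mutating grad_output in place):

import torch

class ClippedSTE(QuantGrad):
    """Hypothetical subclass: Straight-Through Estimator that zeroes
    gradients where the quantized value fell outside [qmin, qmax]."""
    @staticmethod
    def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
        # keep gradients only where quantization stayed in range
        mask = (tensor_q >= qmin) & (tensor_q <= qmax)
        return grad_output * mask.type_as(grad_output)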
def _check_weight(module):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .apply_compression import apply_compression_results
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
logger = logging.getLogger('torch apply compression')
def apply_compression_results(model, masks_file, map_location=None):
"""
Apply the masks from ```masks_file``` to the model
Note: this API is for inference only, because it simply multiplies weights by
their corresponding masks when it is called.
Parameters
----------
model : torch.nn.Module
The model to be compressed
masks_file : str
The path of the mask file
map_location : str
the device on which masks are placed, same as map_location in ```torch.load```
"""
masks = torch.load(masks_file, map_location)
for name, module in model.named_modules():
if name in masks:
module.weight.data = module.weight.data.mul_(masks[name]['weight'])
if hasattr(module, 'bias') and module.bias is not None and 'bias' in masks[name]:
module.bias.data = module.bias.data.mul_(masks[name]['bias'])
\ No newline at end of file
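A usage sketch for the helper above (the model class and mask path are placeholders, and the import path may differ across NNI versions):

import torch
from nni.compression.pytorch import apply_compression_results  # assumed import path

model = MyModel()  # hypothetical pruned model definition
apply_compression_results(model, 'masks.pth', map_location='cpu')
model.eval()  # weights are now multiplied by their masks, ready for inference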
......@@ -10,6 +10,7 @@ _logger = logging.getLogger(__name__)
replace_module = {
'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
'Conv2d': lambda module, mask: replace_conv2d(module, mask),
'ConvTranspose2d': lambda module, mask: replace_convtranspose2d(module, mask),
'MaxPool2d': lambda module, mask: no_replace(module, mask),
'AvgPool2d': lambda module, mask: no_replace(module, mask),
'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
......@@ -22,6 +23,7 @@ replace_module = {
'Dropout3d': lambda module, mask: no_replace(module, mask)
}
def no_replace(module, mask):
"""
No need to replace
......@@ -29,6 +31,7 @@ def no_replace(module, mask):
_logger.debug("no need to replace")
return module
def replace_linear(linear, mask):
"""
Parameters
......@@ -54,11 +57,13 @@ def replace_linear(linear, mask):
out_features=linear.out_features,
bias=linear.bias is not None)
new_linear.to(linear.weight.device)
new_linear.weight.data = torch.index_select(linear.weight.data, -1, index.to(linear.weight.device))
new_linear.weight.data = torch.index_select(
linear.weight.data, -1, index.to(linear.weight.device))
if linear.bias is not None:
new_linear.bias.data.copy_(linear.bias.data)
return new_linear
def replace_batchnorm2d(norm, mask):
"""
Parameters
......@@ -87,10 +92,13 @@ def replace_batchnorm2d(norm, mask):
new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, index)
if norm.track_running_stats:
new_norm.running_mean.data = torch.index_select(norm.running_mean.data, 0, index)
new_norm.running_var.data = torch.index_select(norm.running_var.data, 0, index)
new_norm.running_mean.data = torch.index_select(
norm.running_mean.data, 0, index)
new_norm.running_var.data = torch.index_select(
norm.running_var.data, 0, index)
return new_norm
def replace_conv2d(conv, mask):
"""
Parameters
......@@ -121,7 +129,8 @@ def replace_conv2d(conv, mask):
# remove groups for depthwise layers
assert in_channels == out_channels
groups = in_channels
_logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d", mask.module_name, in_channels, out_channels)
_logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d",
mask.module_name, in_channels, out_channels)
new_conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=conv.kernel_size,
......@@ -136,9 +145,11 @@ def replace_conv2d(conv, mask):
tmp_weight_data = tmp_bias_data = None
if mask.output_mask is not None:
tmp_weight_data = torch.index_select(conv.weight.data, 0, out_channels_index)
tmp_weight_data = torch.index_select(
conv.weight.data, 0, out_channels_index)
if conv.bias is not None:
tmp_bias_data = torch.index_select(conv.bias.data, 0, out_channels_index)
tmp_bias_data = torch.index_select(
conv.bias.data, 0, out_channels_index)
else:
tmp_weight_data = conv.weight.data
# For the convolutional layers that have more than one group
......@@ -152,24 +163,120 @@ def replace_conv2d(conv, mask):
for groupid in range(conv.groups):
start = groupid * input_step
end = (groupid + 1) * input_step
current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
current_input_index = list(
filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
if not current_input_index:
# there is no kept channel in current group
continue
# TODO bug here, the groups is directly get from conv.groups, if the whole group is removed,
# then the number of groups in the new_conv also need to change
raise Exception(
"Removing the whole group of filters is not supported yet, except for depth-wise conv")
# shift the global index into the group index
current_input_index = [x-start for x in current_input_index]
# if the groups is larger than 1, the input channels of each
# group should be pruned evenly.
assert len(current_input_index) == in_channels_group, \
'Input channels of each group are not pruned evenly'
current_input_index = torch.tensor(current_input_index).to(tmp_weight_data.device) # pylint: disable=not-callable
current_input_index = torch.tensor(current_input_index).to(tmp_weight_data.device) # pylint: disable=not-callable
f_start = groupid * filter_step
f_end = (groupid + 1) * filter_step
new_conv.weight.data[f_start:f_end] = torch.index_select(tmp_weight_data[f_start:f_end], 1, current_input_index)
new_conv.weight.data[f_start:f_end] = torch.index_select(
tmp_weight_data[f_start:f_end], 1, current_input_index)
else:
new_conv.weight.data.copy_(tmp_weight_data)
if conv.bias is not None:
new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
new_conv.bias.data.copy_(
conv.bias.data if tmp_bias_data is None else tmp_bias_data)
return new_conv
def replace_convtranspose2d(convtrans, mask):
"""
We need another replace function for
ConvTranspose2d because the layout of
its weight differs from traditional
conv layers: the weight layout is [N_in, N_out, ksize_1, ksize_2]
Parameters
----------
convtrans : torch.nn.ConvTranspose2d
The ConvTranspose2d module to be replaced
mask : ModuleMasks
The masks of this module
Returns
-------
torch.nn.ConvTranspose2d
The new ConvTranspose2d module
"""
assert isinstance(mask, ModuleMasks)
assert isinstance(convtrans, torch.nn.ConvTranspose2d)
if mask.input_mask is None:
in_channels = convtrans.in_channels
else:
in_channels_index = mask.input_mask.mask_index[1]
in_channels = in_channels_index.size(0)
if mask.output_mask is None:
out_channels = convtrans.out_channels
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size(0)
groups = convtrans.groups
# check if can remove the whole group of filters
if convtrans.in_channels == convtrans.out_channels == convtrans.groups:
# remove groups for depthwise layers
# this needs the group dependency to be fixed before the speedup
assert in_channels == out_channels
groups = in_channels
_logger.debug('Replace convtranspose2d %s with in_channels:%d out_channels:%d',
mask.module_name, in_channels, out_channels)
new_convtrans = torch.nn.ConvTranspose2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=convtrans.kernel_size,
stride=convtrans.stride,
padding=convtrans.padding,
dilation=convtrans.dilation,
groups=groups,
bias=convtrans.bias is not None,
padding_mode=convtrans.padding_mode)
new_convtrans.to(convtrans.weight.device)
tmp_weight_data = None
if mask.input_mask is not None:
# in convtranspose2d we need to select the input channel first
tmp_weight_data = torch.index_select(
convtrans.weight.data, 0, in_channels_index)
else:
tmp_weight_data = convtrans.weight.data
# we need to handle the output channel group by group like the conv layer
out_step = int(convtrans.out_channels / convtrans.groups)
out_channel_group = int(out_channels/groups)
new_in_per_group = int(in_channels/groups)
if mask.output_mask is not None and not(in_channels == out_channels == groups):
for groupid in range(convtrans.groups):
start = groupid * out_step
end = (groupid + 1) * out_step
current_output_index = list(
filter(lambda x: start <= x and x < end, out_channels_index.tolist()))
# we need to shift the index into the group-wise
current_output_index = [x-start for x in current_output_index]
if not current_output_index:
# No kept channel in the current group
raise Exception(
"Removing the whole group of filters is not supported yet, except for depth-wise conv")
assert len(current_output_index) == out_channel_group, \
'Output channel of each group should be the same after pruning'
current_output_index = torch.tensor(current_output_index).to(tmp_weight_data.device) # pylint: disable=not-callable
new_start = groupid * new_in_per_group
new_end = (groupid + 1) * new_in_per_group
new_convtrans.weight.data[new_start:new_end] = torch.index_select(
tmp_weight_data[new_start:new_end], 1, current_output_index)
else:
new_convtrans.weight.data.copy_(tmp_weight_data)
if convtrans.bias is not None:
if mask.output_mask is not None:
new_convtrans.bias.data[:] = torch.index_select(
convtrans.bias.data, 0, out_channels_index)
else:
new_convtrans.bias.data.copy_(convtrans.bias.data)
return new_convtrans
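The layout difference that motivates this separate replace function can be verified directly; a minimal sketch:

import torch

# Conv2d weight layout: [out_channels, in_channels // groups, kH, kW]
# ConvTranspose2d weight layout: [in_channels, out_channels // groups, kH, kW]
conv = torch.nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3)
convtrans = torch.nn.ConvTranspose2d(in_channels=4, out_channels=8, kernel_size=3)
print(conv.weight.shape)       # torch.Size([8, 4, 3, 3])
print(convtrans.weight.shape)  # torch.Size([4, 8, 3, 3])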
......@@ -13,6 +13,7 @@ _logger = logging.getLogger(__name__)
conv_prune_dim = -1
def set_conv_prune_dim(dim):
"""
Parameters:
......@@ -23,6 +24,7 @@ def set_conv_prune_dim(dim):
global conv_prune_dim
conv_prune_dim = dim
class CoarseMask:
"""
Coarse grained mask for a given tensor, here tensor could be weights,
......@@ -228,6 +230,7 @@ Infer input and output shape of a module/function from its weight mask
infer_from_mask = {
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask),
'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_mask(module_masks, mask),
'Linear': lambda module_masks, mask, shape: linear_mask(module_masks, mask, shape)
}
......@@ -246,6 +249,7 @@ infer_from_inshape = {
'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_inshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
......@@ -277,6 +281,7 @@ Infer input and weight shape of a module/function from its output shape
"""
infer_from_outshape = {
'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask),
'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_outshape(module_masks, mask),
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_outshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
......@@ -306,6 +311,7 @@ infer_from_outshape = {
'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask)
}
def dropout_inshape(module_masks, mask):
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
......@@ -325,6 +331,7 @@ def dropout_inshape(module_masks, mask):
module_masks.set_output_mask(mask)
return module_masks.output_mask
def dropout_outshape(module_masks, mask):
if module_masks.output_mask is None:
module_masks.set_output_mask(mask)
......@@ -335,6 +342,7 @@ def dropout_outshape(module_masks, mask):
return module_masks.output_mask
def cat_inshape(module_masks, mask, cat_info, last_visited):
"""
Inference the output mask of the cat operation from the
......@@ -433,6 +441,7 @@ def add_inshape(module_masks, mask):
raise Exception('Mask conflict happens!')
return None
def add_outshape(module_masks, mask):
"""
Inference the input mask of the add operation from the
......@@ -445,9 +454,11 @@ def add_outshape(module_masks, mask):
module_masks.set_input_mask(mask)
return mask
else:
assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
return mask
def batchnorm2d_inshape(module_masks, mask):
"""
We assume only the second dimension has coarse grained mask
......@@ -477,6 +488,7 @@ def batchnorm2d_inshape(module_masks, mask):
module_masks.set_param_masks('bias', weight_cmask)
return mask
def batchnorm2d_outshape(module_masks, mask):
"""
We assume only the second dimension has coarse grained mask
......@@ -577,6 +589,7 @@ def view_inshape(module_masks, mask, shape):
module_masks.set_output_mask(output_cmask)
return output_cmask
def view_outshape(module_masks, mask, shape):
"""
Parameters
......@@ -614,12 +627,14 @@ def view_outshape(module_masks, mask, shape):
return input_cmask
def size_inshape(module_masks, mask):
"""
No need to do anything for this ```size``` op
"""
return None
def mean_inshape(module_masks, mask, shape):
"""
Similar to view operation, currently mask inference only supports
......@@ -642,6 +657,7 @@ def mean_inshape(module_masks, mask, shape):
module_masks.set_output_mask(output_cmask)
return output_cmask
def mean_outshape(module_masks, mask, shape):
"""
Similar to view operation, currently mask inference only supports
......@@ -662,6 +678,7 @@ def mean_outshape(module_masks, mask, shape):
module_masks.set_input_mask(input_cmask)
return input_cmask
def maxpool2d_inshape(module_masks, mask):
"""
Assume only the second dimension is masked
......@@ -690,6 +707,7 @@ def maxpool2d_inshape(module_masks, mask):
module_masks.set_output_mask(mask)
return mask
def maxpool2d_outshape(module_masks, mask):
"""
Assume only the second dimension is masked
......@@ -714,6 +732,7 @@ def maxpool2d_outshape(module_masks, mask):
module_masks.set_output_mask(mask)
return mask
def relu_inshape(module_masks, mask):
"""
Parameters
......@@ -737,6 +756,7 @@ def relu_inshape(module_masks, mask):
module_masks.set_output_mask(mask)
return mask
def relu_outshape(module_masks, mask):
"""
Parameters
......@@ -754,11 +774,13 @@ def relu_outshape(module_masks, mask):
assert isinstance(mask, CoarseMask)
if module_masks.output_mask is not None:
# mask conflict should be solved before speedup
assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return mask
def batchnorm2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
......@@ -792,6 +814,7 @@ def batchnorm2d_mask(module_masks, mask):
module_masks.set_output_mask(output_cmask)
return input_cmask, output_cmask
def linear_mask(module_masks, mask, shape):
"""
Infer input and output shape from weight mask with limitations:
......@@ -825,6 +848,7 @@ def linear_mask(module_masks, mask, shape):
module_masks.set_input_mask(input_cmask)
return input_cmask, None
def conv2d_mask(module_masks, mask):
"""
Infer input and output shape from weight mask
......@@ -863,8 +887,9 @@ def conv2d_mask(module_masks, mask):
weight_mask = mask['weight']
sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
if len(index) == weight_mask.shape[dim]: # full mask
index = torch.nonzero(weight_mask.abs().sum(
sum_idx) != 0, as_tuple=True)[0]
if len(index) == weight_mask.shape[dim]: # full mask
index = None
if index is None:
......@@ -882,7 +907,8 @@ def conv2d_mask(module_masks, mask):
bias_cmask.add_index_mask(dim=0, index=bias_index)
return index, weight_cmask, bias_cmask
index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask, dim=conv_prune_dim)
index, weight_cmask, bias_cmask = convert_to_coarse_mask(
mask, dim=conv_prune_dim)
if index is None:
# TODO: fine grained mask speedup
......@@ -910,7 +936,8 @@ def conv2d_mask(module_masks, mask):
module_masks.set_input_mask(io_cmask)
else:
assert module_masks.input_mask == io_cmask
return module_masks.input_mask, None
return module_masks.input_mask, None
def conv2d_inshape(module_masks, mask):
"""
......@@ -972,7 +999,8 @@ def conv2d_outshape(module_masks, mask):
# mask conflict should be solved by fix_mask_conflict before speedup
# mask and module_masks.output_mask may have different number of dimensions
# since they could be passed by linear or conv2d
assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
......@@ -988,3 +1016,74 @@ def conv2d_outshape(module_masks, mask):
module_masks.input_mask = mask
return mask
return None
def convtranspose2d_mask(module_masks, mask):
# TODO support the Convtranspose2d Pruning for the L1FilterPruner
raise Exception(
"Current Filter pruner cannot prune the ConvTranspose2d, will support pruning ConvTranspose2d later")
def convtranspose2d_inshape(module_masks, mask):
"""
Shape change of input tensor does not affect the shape of its output tensor
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the ConvTranspose2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
if module_masks.input_mask is None:
module_masks.set_input_mask(mask)
else:
# the same conv layer may be accessed more
# than once, such as a concat operation.
# mask conflict should be solved by fix_mask_conflict before speedup
assert module_masks.input_mask == mask
# shape changes pass through depth-wise conv layers
m = module_masks.module
if m.in_channels == m.out_channels == m.groups:
module_masks.output_mask = mask
module_masks.input_mask = mask
return mask
return None
def convtranspose2d_outshape(module_masks, mask):
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
if module_masks.output_mask is None:
module_masks.output_mask = mask
else:
# mask conflict should be solved by fix_mask_conflict before speedup
# mask and module_masks.output_mask may have different number of dimensions
# since they could be passed by linear or conv2d
assert all(
module_masks.output_mask.mask_index[1] == mask.mask_index[1])
weight_cmask = CoarseMask(num_dim=4)
# Note the memory layout of Convtranspose2d is C_in, C_out, k1, k2
weight_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
module_masks.set_param_masks('bias', bias_cmask)
# shape changes pass through depth-wise conv layers
m = module_masks.module
if m.in_channels == m.out_channels == m.groups:
module_masks.output_mask = mask
module_masks.input_mask = mask
return mask
return None
......@@ -9,6 +9,7 @@ from .utils import get_module_by_name
# logging.basicConfig(level = logging.DEBUG)
_logger = logging.getLogger(__name__)
def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
"""
MaskConflict fixes the mask conflicts for the channel dependencies
......@@ -50,6 +51,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
masks = padding_cat_mask.fix_mask()
return masks, fix_channel_mask.conv_prune_dim
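A usage sketch (model, input shape, and import path are assumptions; adjust them to your setup):

import torch
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict  # assumed path

model = MyModel()                          # hypothetical model with channel dependencies
dummy_input = torch.randn(1, 3, 224, 224)  # illustrative input shape
masks = torch.load('masks.pth')
fixed_masks, prune_dim = fix_mask_conflict(masks, model, dummy_input)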
class MaskFix:
def __init__(self, masks, model=None, dummy_input=None, traced=None):
# check if the parameters are valid
......@@ -74,6 +76,7 @@ class MaskFix:
"""
torch.save(self.masks, path)
class CatMaskPadding(MaskFix):
def __init__(self, masks, model, dummy_input=None, traced=None):
"""
......@@ -100,7 +103,8 @@ class CatMaskPadding(MaskFix):
super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced)
def fix_mask(self):
cat_padding_depen = CatPaddingDependency(self.model, self.dummy_input, self.traced)
cat_padding_depen = CatPaddingDependency(
self.model, self.dummy_input, self.traced)
name_to_module = {}
for name, module in self.model.named_modules():
name_to_module[name] = module
......@@ -131,11 +135,10 @@ class CatMaskPadding(MaskFix):
# module.bias may be None
b_shape = module.bias.data.size()
b_mask = torch.ones(b_shape).to(device)
self.masks[layer] = {'weight':w_mask, 'bias':b_mask}
self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
return self.masks
class GroupMaskConflict(MaskFix):
def __init__(self, masks, model=None, dummy_input=None, traced=None):
"""
......@@ -154,8 +157,8 @@ class GroupMaskConflict(MaskFix):
the traced model of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
super(GroupMaskConflict, self).__init__(masks, model, dummy_input, traced)
super(GroupMaskConflict, self).__init__(
masks, model, dummy_input, traced)
def fix_mask(self):
"""
......@@ -163,7 +166,8 @@ class GroupMaskConflict(MaskFix):
has group dependencies. This function should be called before the
mask inference of the 'speedup' module.
"""
group_depen = GroupDependency(self.model, self.dummy_input, self.traced)
group_depen = GroupDependency(
self.model, self.dummy_input, self.traced)
depens = group_depen.dependency
_logger.info(depens)
for layername in depens:
......@@ -174,8 +178,10 @@ class GroupMaskConflict(MaskFix):
w_mask = self.masks[layername]['weight']
shape = w_mask.size()
count = np.prod(shape[1:])
all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
all_ones = (w_mask.flatten(1).sum(-1) ==
count).nonzero().squeeze(1).tolist()
all_zeros = (w_mask.flatten(1).sum(-1) ==
0).nonzero().squeeze(1).tolist()
if len(all_ones) + len(all_zeros) < w_mask.size(0):
# In fine-grained pruning, skip this layer
_logger.info('Layers %s using fine-grained pruning', layername)
......@@ -190,7 +196,8 @@ class GroupMaskConflict(MaskFix):
for i in range(group):
_start = step * i
_end = step * (i+1)
_tmp_list = list(filter(lambda x: _start <= x and x < _end, all_zeros))
_tmp_list = list(
filter(lambda x: _start <= x and x < _end, all_zeros))
group_masked.append(_tmp_list)
mini_masked = min([len(x) for x in group_masked])
for gm in group_masked:
......@@ -198,13 +205,13 @@ class GroupMaskConflict(MaskFix):
# To keep the number of pruned output channels divisible by the number of
# groups, we reset the masks of the extra filters in each group to ones (kept).
pos = gm[i]
self.masks[layername]['weight'][pos] = torch.ones(shape[1:])
if hasattr(self.masks[layername], 'bias'):
self.masks[layername]['weight'][pos] = torch.ones(
shape[1:])
if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
self.masks[layername]['bias'][pos] = 1
return self.masks
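To illustrate the invariant GroupMaskConflict enforces, a toy sketch with made-up numbers: each group must end up with the same number of pruned filters, so the surplus pruned filters are restored.

# Toy numbers: 8 filters in 2 groups of 4; the pruner zeroed filters {0, 1, 5}.
groups, filters_per_group = 2, 4
all_zeros = [0, 1, 5]
group_masked = [
    [f for f in all_zeros if g * filters_per_group <= f < (g + 1) * filters_per_group]
    for g in range(groups)
]                                                # [[0, 1], [5]]
mini_masked = min(len(g) for g in group_masked)  # 1 pruned filter per group survives
restored = [g[mini_masked:] for g in group_masked]
print(restored)  # [[1], []] -- filter 1 gets its mask reset to ones (kept)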
class ChannelMaskConflict(MaskFix):
def __init__(self, masks, model=None, dummy_input=None, traced=None):
"""
......@@ -223,7 +230,8 @@ class ChannelMaskConflict(MaskFix):
the traced graph of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
super(ChannelMaskConflict, self).__init__(masks, model, dummy_input, traced)
super(ChannelMaskConflict, self).__init__(
masks, model, dummy_input, traced)
self.conv_prune_dim = detect_mask_prune_dim(masks, model)
_logger.info('detected conv prune dim: %s', self.conv_prune_dim)
......@@ -235,9 +243,11 @@ class ChannelMaskConflict(MaskFix):
are supported.
"""
if self.conv_prune_dim == 0:
channel_depen = ChannelDependency(self.model, self.dummy_input, self.traced)
channel_depen = ChannelDependency(
self.model, self.dummy_input, self.traced)
else:
channel_depen = InputChannelDependency(self.model, self.dummy_input, self.traced)
channel_depen = InputChannelDependency(
self.model, self.dummy_input, self.traced)
depen_sets = channel_depen.dependency_sets
sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
for dset in depen_sets:
......@@ -262,17 +272,29 @@ class ChannelMaskConflict(MaskFix):
channel_masks.append((mask.abs().sum(0) != 0).int())
elif type(m).__name__ == 'BatchNorm2d':
channel_masks.append(mask.int())
elif type(m).__name__ == 'ConvTranspose2d':
# ConvTranspose2d has a different memory layout, so we need to create
# a tmp_sum_idx for conv_transpose
tmp_sum_idx = (
0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3)
channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int()
channel_masks.append(channel_mask)
if (channel_mask.sum() * (mask.numel() / mask.shape[1-self.conv_prune_dim])).item() != (mask > 0).sum().item():
fine_grained = True
else:
raise RuntimeError(f'unsupported module type: {type(m).__name__}')
raise RuntimeError(
f'unsupported module type: {type(m).__name__}')
else:
# no mask means not pruned, equivalent to full masks
channel_masks.append(None)
if fine_grained:
_logger.info('fine-grained mask detected, skip solving conflict for this set: %s', dset)
_logger.info(
'fine-grained mask detected, skip solving conflict for this set: %s', dset)
continue
if all(x is None for x in channel_masks):
continue
num_channels_list = [len(x) for x in channel_masks if x is not None]
num_channels_list = [len(x)
for x in channel_masks if x is not None]
# number of channels in same set should be identical
assert len(set(num_channels_list)) == 1
num_channels = num_channels_list[0]
......@@ -284,7 +306,8 @@ class ChannelMaskConflict(MaskFix):
# merge masks with 'or'
merged_channel_mask = channel_masks[0].clone()
for i in range(1, len(channel_masks)):
merged_channel_mask = ((merged_channel_mask + channel_masks[i]) != 0).int()
merged_channel_mask = (
(merged_channel_mask + channel_masks[i]) != 0).int()
merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]
......@@ -305,16 +328,19 @@ class ChannelMaskConflict(MaskFix):
elif type(m).__name__ == 'BatchNorm2d':
new_mask = merged_index.type_as(orig_mask)
else:
raise RuntimeError(f'unsupported module type: {type(m).__name__}')
raise RuntimeError(
f'unsupported module type: {type(m).__name__}')
self.masks[name]['weight'] = new_mask
if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
if type(m).__name__ == 'Conv2d':
assert self.conv_prune_dim == 0
self.masks[name]['bias'] = merged_channel_mask.type_as(self.masks[name]['bias'])
self.masks[name]['bias'] = merged_channel_mask.type_as(
self.masks[name]['bias'])
return self.masks
def detect_mask_prune_dim(masks, model):
"""
Detect how the masks of convolutional layers are pruned.
......@@ -358,7 +384,8 @@ def detect_mask_prune_dim(masks, model):
_logger.warning('no multi-dimension masks found.')
return 0
dim0_sparsity, dim1_sparsity = 1. - dim0_preserved / dim0_num, 1. - dim1_preserved / dim1_num
dim0_sparsity, dim1_sparsity = 1. - dim0_preserved / \
dim0_num, 1. - dim1_preserved / dim1_num
_logger.info('dim0 sparsity: %f', dim0_sparsity)
_logger.info('dim1 sparsity: %f', dim1_sparsity)
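A sketch of the per-mask computation feeding these statistics (illustrative 4-d mask; the real function aggregates over all conv masks):

import torch

# One conv mask of shape [filters, channels, kH, kW] with filter 0 fully pruned.
w_mask = torch.ones(4, 3, 3, 3)
w_mask[0] = 0

dim0_preserved = (w_mask.sum((1, 2, 3)) != 0).sum().item()  # 3 of 4 filters kept
dim1_preserved = (w_mask.sum((0, 2, 3)) != 0).sum().item()  # 3 of 3 channels kept
dim0_sparsity = 1. - dim0_preserved / w_mask.shape[0]       # 0.25
dim1_sparsity = 1. - dim1_preserved / w_mask.shape[1]       # 0.0
# Higher dim0 sparsity suggests the pruner worked on dim 0 (output filters).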
......
......@@ -4,13 +4,16 @@
import csv
import logging
__all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency', 'InputChannelDependency']
__all__ = ['ChannelDependency', 'GroupDependency',
'CatPaddingDependency', 'InputChannelDependency']
CONV_TYPE = 'aten::_convolution'
ADD_TYPES = ['aten::add', 'aten::add_']
CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')
RESHAPE_OPS = [CAT_TYPE, 'aten::view', 'aten::reshape', 'aten::flatten', 'aten::mean']
RESHAPE_OPS = [CAT_TYPE, 'aten::view',
'aten::reshape', 'aten::flatten', 'aten::mean']
class Dependency:
def __init__(self, model=None, dummy_input=None, traced_model=None):
......@@ -34,6 +37,7 @@ class Dependency:
def export(self, filepath):
raise NotImplementedError
class ChannelDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
......@@ -50,7 +54,8 @@ class ChannelDependency(Dependency):
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
super(ChannelDependency, self).__init__(model, dummy_input, traced_model)
super(ChannelDependency, self).__init__(
model, dummy_input, traced_model)
def _get_parent_layers(self, node):
"""
......@@ -71,7 +76,7 @@ class ChannelDependency(Dependency):
queue.append(node)
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear':
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
# find the first met conv
parent_layers.append(curnode.name)
continue
......@@ -119,7 +124,6 @@ class ChannelDependency(Dependency):
for _node in dependency_set:
self.dependency[_node] = dependency_set
def export(self, filepath):
"""
export the channel dependencies as a csv file.
......@@ -185,6 +189,7 @@ class ChannelDependency(Dependency):
d_sets.append(tmp_set)
return d_sets
def reshape_break_channel_dependency(op_node):
"""
The reshape operations such as (reshape, view, flatten) may break
......@@ -213,6 +218,7 @@ def reshape_break_channel_dependency(op_node):
out_channel = out_shape[1]
return in_channel != out_channel
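A sketch of the check with concrete shapes (the real function reads them from the traced graph's input and output tensors):

# aten::flatten on [N, 64, 7, 7] -> [N, 3136]: channel counts differ,
# so the channel dependency is broken and propagation stops here.
in_shape, out_shape = (1, 64, 7, 7), (1, 64 * 7 * 7)
print(in_shape[1] != out_shape[1])  # True -> dependency broken

# A view that keeps dim 1, e.g. [N, 64, 49] -> [N, 64, 7, 7], preserves it.
print((1, 64, 49)[1] != (1, 64, 7, 7)[1])  # False -> dependency preserved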
class InputChannelDependency(ChannelDependency):
"""
Some pruners may prune the input channel of the convolutional
......@@ -242,7 +248,8 @@ class InputChannelDependency(ChannelDependency):
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
super(InputChannelDependency, self).__init__(model, dummy_input, traced_model)
super(InputChannelDependency, self).__init__(
model, dummy_input, traced_model)
def _get_following_convs(self, tensor):
queue = []
......@@ -250,14 +257,14 @@ class InputChannelDependency(ChannelDependency):
queue.extend(self.graph.input_to_node[tensor])
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear':
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
# find the first met conv
key_layers.append(curnode.name)
continue
elif curnode.op_type in RESHAPE_OPS:
# check if the reshape operation will break the channel dependency
if reshape_break_channel_dependency(curnode):
# reshape operations also breaks the dependency relationship
# reshape operations also breaks the dependency relationship
continue
successors = self.graph.find_successors(curnode.unique_name)
successors = [self.graph.name_to_node[name] for name in successors]
......@@ -290,7 +297,8 @@ class InputChannelDependency(ChannelDependency):
class CatPaddingDependency(ChannelDependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
super(CatPaddingDependency, self).__init__(model, dummy_input, traced_model)
super(CatPaddingDependency, self).__init__(
model, dummy_input, traced_model)
def build_dependency(self):
"""
......@@ -347,6 +355,7 @@ class CatPaddingDependency(ChannelDependency):
row.extend(list(layers))
csv_w.writerow(row)
class GroupDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
......@@ -388,7 +397,7 @@ class GroupDependency(Dependency):
queue = predeessors
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d':
if curnode.op_type == 'Conv2d' or curnode.op_type == 'ConvTranspose2d':
# find the first met conv
parent_layers.append(curnode.name)
continue
......@@ -412,7 +421,8 @@ class GroupDependency(Dependency):
group : int
the number of the groups of the target conv layer.
"""
cpp_conv = list(filter(lambda x: x.kind() == CONV_TYPE, node_group.node_cpps))
cpp_conv = list(filter(lambda x: x.kind() ==
CONV_TYPE, node_group.node_cpps))
assert len(cpp_conv) == 1
cpp_conv = cpp_conv[0]
inputs = list(cpp_conv.inputs())
......@@ -442,12 +452,14 @@ class GroupDependency(Dependency):
filters should be divisible to.
"""
for node in self.graph.nodes_py.nodes_op:
if node.op_type == 'Conv2d':
if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
group = self._get_conv_groups(node)
if node.name in self.dependency:
# a conv layer whose group count is larger than 1 requires that
# its number of output channels be divisible by the number of groups.
self.dependency[node.name] = max(self.dependency[node.name], group)
self.dependency[node.name] = max(
self.dependency[node.name], group)
else:
self.dependency[node.name] = group
if group > 1:
......@@ -456,7 +468,8 @@ class GroupDependency(Dependency):
parent_convs = self._get_parent_convs(node)
for parent in parent_convs:
if parent in self.dependency:
self.dependency[parent] = max(self.dependency[parent], group)
self.dependency[parent] = max(
self.dependency[parent], group)
else:
self.dependency[parent] = group
return self.dependency
......@@ -484,6 +497,7 @@ class GroupDependency(Dependency):
for name in self.dependency:
group = self.dependency[name]
csv_w.writerow([name, group])
@property
def dependency_sets(self):
return self.dependency
......@@ -2,6 +2,6 @@
# Licensed under the MIT license.
from .config import *
from .experiment import Experiment, RetiariiExperiment
from .experiment import Experiment
from .nni_client import *
from .base import ExperimentConfig, RetiariiExpConfig
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .local import LocalExperimentConfig
from .common import *
from .local import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import dataclasses
import json
from pathlib import Path
from typing import Any, Dict, Optional, Union
@dataclasses.dataclass(init=False)
class ExperimentConfig:
experiment_name: str
search_space: Any
max_execution_seconds: Optional[int] = None
max_trial_number: Optional[int] = None
trial_concurrency: int
trial_command: str
trial_code_directory: Union[Path, str]
trial_gpu_number: int = 0
extra_config: Optional[Dict[str, str]] = None
_training_service: str
# these values will be used to create template object,
# and the user should overwrite them later.
_placeholder = {
'experiment_name': '_unset_',
'search_space': '_unset_',
'trial_concurrency': -1,
'trial_command': '_unset_',
'trial_code_directory': '_unset_'
}
# simple validation functions
# complex validation logic with special error message should go to `validate()` method instead
_value_range = {
'max_execution_seconds': lambda x: x is None or x > 0,
'max_trial_number': lambda x: x is None or x > 0,
'trial_concurrency': lambda x: x > 0,
'trial_gpu_number': lambda x: x >= 0
}
def __init__(self, **kwargs):
from typing import Any, Dict, Optional, Type, TypeVar
from ruamel import yaml
from . import util
__all__ = ['ConfigBase', 'PathLike']
T = TypeVar('T', bound='ConfigBase')
PathLike = util.PathLike
def _is_missing(obj: Any) -> bool:
return isinstance(obj, type(dataclasses.MISSING))
class ConfigBase:
"""
Base class of config classes.
Subclass may override `_canonical_rules` and `_validation_rules`,
and `validate()` if the logic is complex.
"""
# Rules to convert field value to canonical format.
# The key is field name.
# The value is callable `value -> canonical_value`
# It is not type-hinted so dataclass won't treat it as field
_canonical_rules = {} # type: ignore
# Rules to validate field value.
# The key is field name.
# The value is callable `value -> valid` or `value -> (valid, error_message)`
# The rule will be called with canonical format and is only called when `value` is not None.
# `error_message` is used when `valid` is False.
# It will be prepended with class name and field name in exception message.
_validation_rules = {} # type: ignore
def __init__(self, *, _base_path: Optional[Path] = None, **kwargs):
"""
Initialize a config object and set some fields.
Name of keyword arguments can either be snake_case or camelCase.
They will be converted to snake_case automatically.
If a field is missing and doesn't have a default value, it will be set to `dataclasses.MISSING`.
"""
kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}
if _base_path is None:
_base_path = Path()
for field in dataclasses.fields(self):
if field.name in kwargs:
setattr(self, field.name, kwargs[field.name])
elif field.default != dataclasses.MISSING:
setattr(self, field.name, field.default)
else:
setattr(self, field.name, type(self)._placeholder[field.name])
value = kwargs.pop(util.case_insensitive(field.name), field.default)
if value is not None and not _is_missing(value):
# relative paths loaded from config file are not relative to pwd
if 'Path' in str(field.type):
value = Path(value).expanduser()
if not value.is_absolute():
value = _base_path / value
# convert nested dict to config type
if isinstance(value, dict):
cls = util.strip_optional(field.type)
if isinstance(cls, type) and issubclass(cls, ConfigBase):
value = cls(**value, _base_path=_base_path)
setattr(self, field.name, value)
if kwargs:
cls = type(self).__name__
fields = ', '.join(kwargs.keys())
raise ValueError(f'{cls}: Unrecognized fields {fields}')
@classmethod
def load(cls: Type[T], path: PathLike) -> T:
"""
Load config from YAML (or JSON) file.
Keys in YAML file can either be camelCase or snake_case.
"""
data = yaml.safe_load(open(path))
if not isinstance(data, dict):
raise ValueError(f'Content of config file {path} is not a dict/object')
return cls(**data, _base_path=Path(path).parent)
def json(self) -> Dict[str, Any]:
"""
Convert config to JSON object.
The keys of returned object will be camelCase.
"""
return dataclasses.asdict(
self.canonical(),
dict_factory = lambda items: dict((util.camel_case(k), v) for k, v in items if v is not None)
)
def canonical(self: T) -> T:
"""
Returns a deep copy, where the fields supporting multiple formats are converted to the canonical format.
Notably, relative paths may be converted to absolute paths.
"""
ret = copy.deepcopy(self)
for field in dataclasses.fields(ret):
key, value = field.name, getattr(ret, field.name)
rule = ret._canonical_rules.get(key)
if rule is not None:
setattr(ret, key, rule(value))
elif isinstance(value, ConfigBase):
setattr(ret, key, value.canonical())
# value will be copied twice, should not be a performance issue anyway
return ret
def validate(self) -> None:
# check existence
for key, placeholder_value in type(self)._placeholder.items():
if getattr(self, key) == placeholder_value:
raise ValueError(f'Field "{key}" is not set')
# TODO: check type
# check value
for key, condition in type(self)._value_range.items():
value = getattr(self, key)
if not condition(value):
raise ValueError(f'Field "{key}" ({repr(value)}) out of range')
# check special fields
if not Path(self.trial_code_directory).is_dir():
raise ValueError(f'Trial code directory "{self.trial_code_directory}" does not exist or is not directory')
def experiment_config_json(self) -> Dict[str, Any]:
# this only contains the common part for most (if not all) training services
# subclasses should override it to provide exclusive fields
return {
'authorName': '_',
'experimentName': self.experiment_name,
'trialConcurrency': self.trial_concurrency,
'maxExecDuration': self.max_execution_seconds or (999 * 24 * 3600),
'maxTrialNum': self.max_trial_number or 99999,
'searchSpace': json.dumps(self.search_space),
'trainingServicePlatform': self._training_service,
'tuner': {'builtinTunerName': '_user_created_'},
**(self.extra_config or {})
}
def cluster_metadata_json(self) -> Any:
# the cluster metadata format is a total mess
# leave it to each subclass until we refactor nni manager
raise NotImplementedError()
@staticmethod
def create_template(training_service: str) -> 'ExperimentConfig':
for cls in ExperimentConfig.__subclasses__():
for field in dataclasses.fields(cls):
if field.name == '_training_service' and field.default == training_service:
return cls()
raise ValueError(f'Unrecognized training service {training_service}')
class RetiariiExpConfig(ExperimentConfig):
@staticmethod
def create_template(training_service: str) -> 'ExperimentConfig':
for cls in ExperimentConfig.__subclasses__():
for field in dataclasses.fields(cls):
if field.name == '_training_service' and field.default == training_service:
config_obj = cls()
config_obj.search_space = {}
config_obj.trial_command = 'python3 -m nni.retiarii.trial_entry'
# FIXME: expose this field to users
config_obj.trial_code_directory = '../..'
return config_obj
"""
Validate the config object and raise Exception if it's ill-formed.
"""
class_name = type(self).__name__
config = self.canonical()
for field in dataclasses.fields(config):
key, value = field.name, getattr(config, field.name)
# check existence
if _is_missing(value):
raise ValueError(f'{class_name}: {key} is not set')
# check type (TODO)
type_name = str(field.type).replace('typing.', '')
optional = any([
type_name.startswith('Optional['),
type_name.startswith('Union[') and 'NoneType' in type_name,
type_name == 'Any'
])
if value is None:
if optional:
continue
else:
raise ValueError(f'{class_name}: {key} cannot be None')
# check value
rule = config._validation_rules.get(key)
if rule is not None:
try:
result = rule(value)
except Exception:
raise ValueError(f'{class_name}: {key} has bad value {repr(value)}')
if isinstance(result, bool):
if not result:
raise ValueError(f'{class_name}: {key} ({repr(value)}) is out of range')
else:
if not result[0]:
raise ValueError(f'{class_name}: {key} {result[1]}')
# check nested config
if isinstance(value, ConfigBase):
value.validate()
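A sketch of a ConfigBase subclass wiring up the two rule hooks described above (the field names are invented for illustration):

import dataclasses
from typing import Optional

@dataclasses.dataclass(init=False)
class WorkerConfig(ConfigBase):  # hypothetical subclass for illustration
    worker_count: int = 1
    log_dir: Optional[str] = None

    # value -> canonical_value (guard against None: rules run during canonical())
    _canonical_rules = {
        'log_dir': lambda value: value.rstrip('/') if value else value
    }
    # value -> valid, or value -> (valid, error_message); only called when not None
    _validation_rules = {
        'worker_count': lambda value: (value > 0, 'must be positive')
    }

config = WorkerConfig(workerCount=4)  # camelCase keys convert to snake_case
config.validate()                     # passes: worker_count > 0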
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from .base import ConfigBase, PathLike
from . import util
__all__ = [
'ExperimentConfig',
'AlgorithmConfig',
'CustomAlgorithmConfig',
'TrainingServiceConfig',
]
@dataclass(init=False)
class _AlgorithmConfig(ConfigBase):
name: Optional[str] = None
class_name: Optional[str] = None
code_directory: Optional[PathLike] = None
class_args: Optional[Dict[str, Any]] = None
def validate(self):
super().validate()
_validate_algo(self)
@dataclass(init=False)
class AlgorithmConfig(_AlgorithmConfig):
name: str
class_args: Optional[Dict[str, Any]] = None
@dataclass(init=False)
class CustomAlgorithmConfig(_AlgorithmConfig):
class_name: str
class_directory: Optional[PathLike] = None
class_args: Optional[Dict[str, Any]] = None
class TrainingServiceConfig(ConfigBase):
platform: str
@dataclass(init=False)
class ExperimentConfig(ConfigBase):
experiment_name: Optional[str] = None
search_space_file: Optional[PathLike] = None
search_space: Any = None
trial_command: str
trial_code_directory: PathLike = '.'
trial_concurrency: int
trial_gpu_number: int = 0
max_experiment_duration: Optional[str] = None
max_trial_number: Optional[int] = None
nni_manager_ip: Optional[str] = None
use_annotation: bool = False
debug: bool = False
log_level: Optional[str] = None
experiment_working_directory: Optional[PathLike] = None
tuner_gpu_indices: Optional[Union[List[int], str]] = None
tuner: Optional[_AlgorithmConfig] = None
accessor: Optional[_AlgorithmConfig] = None
advisor: Optional[_AlgorithmConfig] = None
training_service: TrainingServiceConfig
def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
assert 'training_service' not in kwargs
self.training_service = util.training_service_config_factory(training_service_platform)
def validate(self, initialized_tuner: bool = False) -> None:
super().validate()
if initialized_tuner:
_validate_for_exp(self)
else:
_validate_for_nnictl(self)
## End of public API ##
@property
def _canonical_rules(self):
return _canonical_rules
@property
def _validation_rules(self):
return _validation_rules
_canonical_rules = {
'search_space_file': util.canonical_path,
'trial_code_directory': util.canonical_path,
'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
'experiment_working_directory': util.canonical_path,
'tuner_gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value
}
_validation_rules = {
'search_space_file': lambda value: (Path(value).is_file(), f'"{value}" does not exist or is not regular file'),
'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
'trial_concurrency': lambda value: value > 0,
'trial_gpu_number': lambda value: value >= 0,
'max_experiment_duration': lambda value: util.parse_time(value) > 0,
'max_trial_number': lambda value: value > 0,
'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
'tuner_gpu_indices': lambda value: all(i >= 0 for i in value) and len(value) == len(set(value)),
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
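A construction sketch for the new-style config (values are placeholders; assumes the nni.experiment.config import path):

from pathlib import Path
from nni.experiment.config import ExperimentConfig, AlgorithmConfig  # assumed path

config = ExperimentConfig(training_service_platform='local')
config.experiment_name = 'mnist-demo'  # placeholder values throughout
config.search_space = {'lr': {'_type': 'loguniform', '_value': [1e-4, 1e-1]}}
config.trial_command = 'python3 trial.py'
config.trial_code_directory = Path('.')
config.trial_concurrency = 2
config.tuner = AlgorithmConfig(name='TPE', class_args={'optimize_mode': 'maximize'})
config.training_service.use_active_gpu = False
config.validate()  # raises ValueError if a field is missing or out of range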
def _validate_for_exp(config: ExperimentConfig) -> None:
# validate experiment for nni.Experiment, where tuner is already initialized outside
if config.use_annotation:
raise ValueError('ExperimentConfig: annotation is not supported in this mode')
if util.count(config.search_space, config.search_space_file) != 1:
raise ValueError('ExperimentConfig: exactly one of search_space and search_space_file must be set')
if util.count(config.tuner, config.accessor, config.advisor) != 0:
raise ValueError('ExperimentConfig: tuner, accessor, and advisor must not be set for this mode')
if config.tuner_gpu_indices is not None:
raise ValueError('ExperimentConfig: tuner_gpu_indices is not supported in this mode')
def _validate_for_nnictl(config: ExperimentConfig) -> None:
# validate experiment for normal launching approach
if config.use_annotation:
if util.count(config.search_space, config.search_space_file) != 0:
raise ValueError('ExperimentConfig: search_space and search_space_file must not be set with annotation')
else:
if util.count(config.search_space, config.search_space_file) != 1:
raise ValueError('ExperimentConfig: exactly one of search_space and search_space_file must be set')
if util.count(config.tuner, config.advisor) != 1:
raise ValueError('ExperimentConfig: exactly one of tuner and advisor must be set')
def _validate_algo(algo: AlgorithmConfig) -> None:
if algo.name is None:
if algo.class_name is None:
raise ValueError('Missing algorithm name')
if algo.code_directory is not None and not Path(algo.code_directory).is_dir():
raise ValueError(f'code_directory "{algo.code_directory}" does not exist or is not directory')
else:
if algo.class_name is not None or algo.code_directory is not None:
raise ValueError('When name is set for a registered algorithm, class_name and code_directory cannot be used')
# TODO: verify algorithm installation and class args
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List
from .common import ExperimentConfig
from . import util
_logger = logging.getLogger(__name__)
def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str, Any]:
config.validate(skip_nnictl)
data = config.json()
ts = data.pop('trainingService')
if ts['platform'] == 'openpai':
ts['platform'] = 'pai'
data['authorName'] = 'N/A'
data['experimentName'] = data.get('experimentName', 'N/A')
data['maxExecDuration'] = data.pop('maxExperimentDuration', '999d')
if data['debug']:
data['versionCheck'] = False
data['maxTrialNum'] = data.pop('maxTrialNumber', 99999)
data['trainingServicePlatform'] = ts['platform']
ss = data.pop('searchSpace', None)
ss_file = data.pop('searchSpaceFile', None)
if ss is not None:
ss_file = NamedTemporaryFile('w', delete=False)
json.dump(ss, ss_file, indent=4)
data['searchSpacePath'] = ss_file.name
elif ss_file is not None:
data['searchSpacePath'] = ss_file
if 'experimentWorkingDirectory' in data:
data['logDir'] = data.pop('experimentWorkingDirectory')
for algo_type in ['tuner', 'assessor', 'advisor']:
algo = data.get(algo_type)
if algo is None:
continue
if algo['name'] is not None: # builtin
algo['builtin' + algo_type.title() + 'Name'] = algo.pop('name')
algo.pop('className', None)
algo.pop('codeDirectory', None)
else:
algo.pop('name', None)
class_name_parts = algo.pop('className').split('.')
algo['codeDir'] = algo.pop('codeDirectory', '') + '/'.join(class_name_parts[:-2])
algo['classFileName'] = class_name_parts[-2] + '.py'
algo['className'] = class_name_parts[-1]
tuner_gpu_indices = _convert_gpu_indices(data.pop('tunerGpuIndices', None))
if tuner_gpu_indices is not None:
data['tuner']['gpuIndicies'] = tuner_gpu_indices
data['trial'] = {
'command': data.pop('trialCommand'),
'codeDir': data.pop('trialCodeDirectory'),
'gpuNum': data.pop('trialGpuNumber', '')
}
if ts['platform'] == 'local':
data['localConfig'] = {
'useActiveGpu': ts['useActiveGpu'],
'maxTrialNumPerGpu': ts['maxTrialNumberPerGpu']
}
if ts.get('gpuIndices') is not None:
data['localConfig']['gpuIndices'] = ','.join(str(idx) for idx in ts['gpuIndices'])
elif ts['platform'] == 'remote':
data['remoteConfig'] = {'reuse': ts['reuseMode']}
data['machineList'] = []
for machine in ts['machineList']:
machine = {
'ip': machine['host'],
'username': machine['user'],
'passwd': machine['password'],
'sshKeyPath': machine['sshKeyFile'],
'passphrase': machine['sshPassphrase'],
'gpuIndices': _convert_gpu_indices(machine['gpuIndices']),
'maxTrialNumPerGpu': machine['maxTrialNumPerGpu'],
'useActiveGpu': machine['useActiveGpu'],
'preCommand': machine['trialPrepareCommand']
}
elif ts['platform'] == 'pai':
data['trial']['cpuNum'] = ts['trialCpuNumber']
data['trial']['memoryMB'] = util.parse_size(ts['trialMemorySize'])
data['trial']['image'] = ts['docker_image']
data['paiConfig'] = {
'userName': ts['username'],
'token': ts['token'],
'host': 'https://' + ts['host'],
'reuse': ts['reuseMode']
}
return data
def _convert_gpu_indices(indices):
return ','.join(str(idx) for idx in indices) if indices is not None else None
def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
experiment_config = to_v1_yaml(config, skip_nnictl=True)
ret = []
if config.training_service.platform == 'local':
request_data = dict()
request_data['local_config'] = experiment_config['localConfig']
if request_data['local_config']:
if request_data['local_config'].get('gpuIndices') and isinstance(request_data['local_config'].get('gpuIndices'), int):
request_data['local_config']['gpuIndices'] = str(request_data['local_config'].get('gpuIndices'))
if request_data['local_config'].get('maxTrialNumOnEachGpu'):
request_data['local_config']['maxTrialNumOnEachGpu'] = request_data['local_config'].get('maxTrialNumOnEachGpu')
if request_data['local_config'].get('useActiveGpu'):
request_data['local_config']['useActiveGpu'] = request_data['local_config'].get('useActiveGpu')
ret.append(request_data)
elif config.training_service.platform == 'remote':
request_data = dict()
if experiment_config.get('remoteConfig'):
request_data['remote_config'] = experiment_config['remoteConfig']
else:
request_data['remote_config'] = {'reuse': False}
request_data['machine_list'] = experiment_config['machineList']
if request_data['machine_list']:
for i in range(len(request_data['machine_list'])):
if isinstance(request_data['machine_list'][i].get('gpuIndices'), int):
request_data['machine_list'][i]['gpuIndices'] = str(request_data['machine_list'][i].get('gpuIndices'))
ret.append(request_data)
elif config.training_service.platform == 'openpai':
pai_config_data = dict()
pai_config_data['pai_config'] = experiment_config['paiConfig']
ret.append(pai_config_data)
else:
raise RuntimeError('Unsupported training service ' + config.training_service.platform)
if experiment_config.get('nniManagerIp') is not None:
ret.append({'nni_manager_ip': {'nniManagerIp': experiment_config['nniManagerIp']}})
ret.append({'trial_config': experiment_config['trial']})
return ret
def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]:
experiment_config = to_v1_yaml(config, skip_nnictl=True)
request_data = dict()
request_data['authorName'] = experiment_config['authorName']
request_data['experimentName'] = experiment_config['experimentName']
request_data['trialConcurrency'] = experiment_config['trialConcurrency']
request_data['maxExecDuration'] = util.parse_time(experiment_config['maxExecDuration'])
request_data['maxTrialNum'] = experiment_config['maxTrialNum']
if config.search_space is not None:
request_data['searchSpace'] = json.dumps(config.search_space)
else:
request_data['searchSpace'] = Path(config.search_space_file).read_text()
request_data['trainingServicePlatform'] = experiment_config.get('trainingServicePlatform')
if experiment_config.get('advisor'):
request_data['advisor'] = experiment_config['advisor']
if request_data['advisor'].get('gpuNum'):
_logger.warning('gpuNum is deprecated, please use gpuIndices instead.')
if request_data['advisor'].get('gpuIndices') and isinstance(request_data['advisor'].get('gpuIndices'), int):
request_data['advisor']['gpuIndices'] = str(request_data['advisor'].get('gpuIndices'))
elif experiment_config.get('tuner'):
request_data['tuner'] = experiment_config['tuner']
if request_data['tuner'].get('gpuNum'):
_logger.warning('gpuNum is deprecated, please use gpuIndices instead.')
if request_data['tuner'].get('gpuIndices') and isinstance(request_data['tuner'].get('gpuIndices'), int):
request_data['tuner']['gpuIndices'] = str(request_data['tuner'].get('gpuIndices'))
if 'assessor' in experiment_config:
request_data['assessor'] = experiment_config['assessor']
if request_data['assessor'].get('gpuNum'):
_logger.warning('gpuNum is deprecated, please remove it from your config file.')
else:
request_data['tuner'] = {'builtinTunerName': '_user_created_'}
    # debug mode disables version check by default
    if experiment_config.get('debug') is not None:
        request_data['versionCheck'] = not experiment_config.get('debug')
    # an explicit versionCheck setting overrides the debug default
    if experiment_config.get('versionCheck') is not None:
        request_data['versionCheck'] = experiment_config.get('versionCheck')
if experiment_config.get('logCollection'):
request_data['logCollection'] = experiment_config.get('logCollection')
request_data['clusterMetaData'] = []
if experiment_config['trainingServicePlatform'] == 'local':
request_data['clusterMetaData'].append(
{'key':'codeDir', 'value':experiment_config['trial']['codeDir']})
request_data['clusterMetaData'].append(
{'key': 'command', 'value': experiment_config['trial']['command']})
elif experiment_config['trainingServicePlatform'] == 'remote':
request_data['clusterMetaData'].append(
{'key': 'machine_list', 'value': experiment_config['machineList']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
if not experiment_config.get('remoteConfig'):
# set default value of reuse in remoteConfig to False
experiment_config['remoteConfig'] = {'reuse': False}
request_data['clusterMetaData'].append(
{'key': 'remote_config', 'value': experiment_config['remoteConfig']})
elif experiment_config['trainingServicePlatform'] == 'pai':
request_data['clusterMetaData'].append(
{'key': 'pai_config', 'value': experiment_config['paiConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'kubeflow':
request_data['clusterMetaData'].append(
{'key': 'kubeflow_config', 'value': experiment_config['kubeflowConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'frameworkcontroller':
request_data['clusterMetaData'].append(
{'key': 'frameworkcontroller_config', 'value': experiment_config['frameworkcontrollerConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'aml':
request_data['clusterMetaData'].append(
{'key': 'aml_config', 'value': experiment_config['amlConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
return request_data
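# Illustrative sketch of the REST payload assembled above, for a hypothetical
# local experiment (all values are assumptions, not taken from a real config):
#
#   {
#       'authorName': 'default', 'experimentName': 'example',
#       'trialConcurrency': 1, 'maxExecDuration': 3600, 'maxTrialNum': 10,
#       'searchSpace': '{"lr": {"_type": "loguniform", "_value": [0.0001, 0.1]}}',
#       'trainingServicePlatform': 'local',
#       'tuner': {'builtinTunerName': 'TPE'},
#       'clusterMetaData': [{'key': 'codeDir', 'value': '.'},
#                           {'key': 'command', 'value': 'python3 train.py'}]
#   }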
......@@ -2,39 +2,25 @@
# Licensed under the MIT license.
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict
from typing import List, Optional, Union
from .base import ExperimentConfig
from .common import TrainingServiceConfig
__all__ = ['LocalConfig']
@dataclass(init=False)
class LocalExperimentConfig(ExperimentConfig):
use_active_gpu: bool = False
class LocalConfig(TrainingServiceConfig):
platform: str = 'local'
use_active_gpu: bool
max_trial_number_per_gpu: int = 1
gpu_indices: Optional[Union[List[int], str]] = None
_training_service: str = 'local'
_canonical_rules = {
'gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value
}
def experiment_config_json(self) -> Dict[str, Any]:
ret = super().experiment_config_json()
ret['clusterMetaData'] = [
{
'key': 'codeDir',
'value': str(Path(self.trial_code_directory).resolve())
},
{
'key': 'command',
'value': self.trial_command
}
]
#ret['local_config'] = {
# 'useActiveGpu': self.use_active_gpu
#}
return ret
def cluster_metadata_json(self) -> Any:
return {
'trial_config': {
'command': self.trial_command,
'codeDir': str(Path(self.trial_code_directory).resolve())
}
}
_validation_rules = {
'platform': lambda value: (value == 'local', 'cannot be modified'),
'max_trial_number_per_gpu': lambda value: value > 0,
'gpu_indices': lambda value: all(idx >= 0 for idx in value) and len(value) == len(set(value))
}
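# Illustrative check of the rules above (hypothetical input, assuming the rule
# dicts are accessible as plain class attributes): the canonical rule normalizes
# a comma-separated string, and the validation rule then requires unique,
# non-negative indices.
assert LocalConfig._canonical_rules['gpu_indices']('0,1') == [0, 1]
assert LocalConfig._validation_rules['gpu_indices']([0, 1])
assert not LocalConfig._validation_rules['gpu_indices']([0, 0])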
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Miscellaneous utility functions.
"""
import math
import os.path
from pathlib import Path
from typing import Optional, Union
PathLike = Union[Path, str]
def case_insensitive(key: str) -> str:
return key.lower().replace('_', '')
def camel_case(key: str) -> str:
words = key.split('_')
return words[0] + ''.join(word.title() for word in words[1:])
def canonical_path(path: Optional[PathLike]) -> Optional[str]:
    # Path.resolve() does not work on Windows when the file does not exist, so use os.path instead
return os.path.abspath(os.path.expanduser(path)) if path is not None else None
def count(*values) -> int:
return sum(value is not None and value is not False for value in values)
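# Illustrative behavior of the helpers above; note that count() uses identity
# checks, so 0 still counts while an explicit False does not:
assert case_insensitive('max_Trial_Number') == 'maxtrialnumber'
assert camel_case('max_trial_number') == 'maxTrialNumber'
assert count('x', None, False, 0) == 2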
def training_service_config_factory(platform: str): # -> TrainingServiceConfig
from .common import TrainingServiceConfig
for cls in TrainingServiceConfig.__subclasses__():
if cls.platform == platform:
return cls()
raise ValueError(f'Unrecognized platform {platform}')
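# Illustrative lookup (assumes LocalConfig has been imported elsewhere, so that
# it is registered as a TrainingServiceConfig subclass):
#
#   config = training_service_config_factory('local')
#   assert config.platform == 'local'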
def strip_optional(type_hint):
return type_hint.__args__[0] if str(type_hint).startswith('typing.Optional[') else type_hint
def parse_time(time: str, target_unit: str = 's') -> int:
return _parse_unit(time.lower(), target_unit, _time_units)
def parse_size(size: str, target_unit: str = 'mb') -> int:
return _parse_unit(size.lower(), target_unit, _size_units)
_time_units = {'d': 24 * 3600, 'h': 3600, 'm': 60, 's': 1}
_size_units = {'gb': 1024 * 1024 * 1024, 'mb': 1024 * 1024, 'kb': 1024}
def _parse_unit(string, target_unit, all_units):
for unit, factor in all_units.items():
if string.endswith(unit):
number = string[:-len(unit)]
value = float(number) * factor
return math.ceil(value / all_units[target_unit])
raise ValueError(f'Unsupported unit in "{string}"')
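# Illustrative conversions using the unit tables above:
assert parse_time('1h') == 3600           # hours -> seconds
assert parse_time('90m', 'h') == 2        # rounded up by math.ceil
assert parse_size('1gb', 'mb') == 1024    # binary units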
import atexit
import logging
import socket
from subprocess import Popen
import time
from threading import Thread
from typing import Optional, overload, List, Union, Callable
import time
from typing import Optional, overload
import colorama
import psutil
import nni.runtime.log
from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.tuner import Tuner
from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.converter.graph_gen import convert_to_graph
from .config import ExperimentConfig
from . import launcher
from .pipe import Pipe
from . import rest
_logger = logging.getLogger(__name__)
nni.runtime.log.init_logger_experiment()
_logger = logging.getLogger('nni.experiment')
class Experiment:
"""
Controls an NNI experiment.
    You may either create a new NNI experiment with the constructor and `Experiment.start()`,
# TODO: or control an existing experiment with `Experiment.connect()`.
Create and stop an NNI experiment.
Attributes
----------
......@@ -42,7 +44,7 @@ class Experiment:
Parameters
----------
tuner
A tuner instance. # TODO: accessor / advisor
A tuner instance.
config
Experiment configuration.
"""
......@@ -67,24 +69,24 @@ class Experiment:
A tuner instance.
training_service
Name of training service.
Supported value: "local", "remote", "openpai"/"pai".
Supported value: "local", "remote", "openpai".
"""
...
def __init__(self, tuner: Tuner, config=None, training_service=None):
self.config: ExperimentConfig
self.port: Optional[int] = None
self._dispatcher = MsgDispatcher(tuner, None)
self.tuner: Tuner = tuner
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
self._dispatcher: Optional[MsgDispatcher] = None
self._dispatcher_thread: Optional[Thread] = None
if isinstance(config, str):
config, training_service = None, config
if training_service == 'openpai':
training_service = 'pai'
if config is None:
self.config = ExperimentConfig.create_template(training_service)
self.config = ExperimentConfig(training_service)
else:
self.config = config
......@@ -103,6 +105,8 @@ class Experiment:
debug
Whether to start in debug mode.
"""
atexit.register(self.stop)
if debug:
logging.getLogger('nni').setLevel(logging.DEBUG)
......@@ -112,9 +116,20 @@ class Experiment:
self.port = port # port will be None if start up failed
# dispatcher must be created after pipe initialized
# dispatcher must be launched after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
Thread(target=self._dispatcher.run).start()
self._dispatcher = MsgDispatcher(self.tuner, None)
self._dispatcher_thread = Thread(target=self._dispatcher.run)
self._dispatcher_thread.start()
ips = [self.config.nni_manager_ip]
for interfaces in psutil.net_if_addrs().values():
for interface in interfaces:
if interface.family == socket.AF_INET:
ips.append(interface.address)
ips = [f'http://{ip}:{port}' for ip in ips if ip]
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips)
_logger.info(msg)
# TODO: register experiment management metadata
......@@ -123,27 +138,41 @@ class Experiment:
"""
Stop background experiment.
"""
self._proc.kill()
self._pipe.close()
_logger.info('Stopping experiment...')
atexit.unregister(self.stop)
if self._proc is not None:
self._proc.kill()
if self._pipe is not None:
self._pipe.close()
if self._dispatcher_thread is not None:
self._dispatcher.stopping = True
self._dispatcher_thread.join(timeout=1)
self.port = None
self._proc = None
self._pipe = None
self._dispatcher = None
self._dispatcher_thread = None
def run(self, port: int = 8080, debug: bool = False) -> str:
def run(self, port: int = 8080, debug: bool = False) -> bool:
"""
Run the experiment.
        This function will block until the experiment finishes or errors.
        Return `True` when the experiment is done; return `False` when it fails.
"""
self.start(port, debug)
try:
while True:
time.sleep(10)
status = self.get_status()
if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
return status
if status == 'STOPPED':
return True
if status == 'ERROR':
return False
finally:
self.stop()
......@@ -153,97 +182,3 @@ class Experiment:
raise RuntimeError('Experiment is not running')
resp = rest.get(self.port, '/check-status')
return resp['status']
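# Illustrative lifecycle of the class above (a sketch; the tuner class and the
# config field name are assumptions, not part of this module):
#
#   tuner = MyTuner()                        # any nni.tuner.Tuner subclass
#   exp = Experiment(tuner, 'local')         # a str argument selects the platform
#   exp.config.trial_command = 'python3 train.py'
#   succeeded = exp.run(port=8080)           # blocks until STOPPED or ERROR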
class RetiariiExperiment(Experiment):
def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
applied_mutators: List['Mutator'], strategy: 'BaseStrategy',
tca: 'TraceClassArguments' = None):
self.config: ExperimentConfig = None
self.port: Optional[int] = None
self.base_model = base_model
self.trainer = trainer
self.applied_mutators = applied_mutators
self.strategy = strategy
self.recorded_module_args = tca.recorded_arguments # FIXME: remove this argument
self._dispatcher = RetiariiAdvisor()
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
def _start_strategy(self):
import torch
script_module = torch.jit.script(self.base_model)
base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
assert id(self.trainer) in self.recorded_module_args
trainer_config = self.recorded_module_args[id(self.trainer)]
_logger.info('Starting strategy...')
Thread(target=self.strategy.run, args=(base_model, self.applied_mutators, trainer_config)).start()
_logger.info('Strategy started!')
def start(self, config: ExperimentConfig, port: int = 8080, debug: bool = False) -> None:
"""
Start the experiment in background.
        This method will raise an exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
"""
if debug:
logging.getLogger('nni').setLevel(logging.DEBUG)
self._proc, self._pipe = launcher.start_experiment(config, port, debug)
assert self._proc is not None
assert self._pipe is not None
self.port = port # port will be None if start up failed
# dispatcher must be created after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
Thread(target=self._dispatcher.run).start()
self._start_strategy()
# TODO: register experiment management metadata
def stop(self) -> None:
"""
Stop background experiment.
"""
self._proc.kill()
self._pipe.close()
self.port = None
self._proc = None
self._pipe = None
def run(self, config: ExperimentConfig, port: int = 8080, debug: bool = False) -> str:
"""
Run the experiment.
This function will block until experiment finish or error.
"""
self.start(config, port, debug)
try:
while True:
time.sleep(10)
status = self.get_status()
if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
return status
finally:
self.stop()
def get_status(self) -> str:
if self.port is None:
raise RuntimeError('Experiment is not running')
resp = rest.get(self.port, '/check-status')
return resp['status']
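# Illustrative usage of RetiariiExperiment (all names hypothetical; note that,
# unlike Experiment, the config is passed to run()/start() rather than to the
# constructor):
#
#   exp = RetiariiExperiment(base_model, trainer, mutators, strategy, tca=tca)
#   exp.run(config, port=8081)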
import contextlib
import logging
from pathlib import Path
import socket
from subprocess import Popen
......@@ -6,40 +7,46 @@ import sys
import time
from typing import Optional, Tuple
import colorama
import nni.runtime.protocol
import nni_node
from .config import ExperimentConfig
from .config import convert
from . import management
from .pipe import Pipe
from . import rest
_logger = logging.getLogger('nni.experiment')
def start_experiment(config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
pipe = None
proc = None
config.validate()
config.validate(initialized_tuner=True)
_ensure_port_idle(port)
if config._training_service == 'pai':
if config.training_service.platform == 'openpai':
_ensure_port_idle(port + 1, 'OpenPAI requires an additional port')
exp_id = management.generate_experiment_id()
try:
print(f'Creating experiment {exp_id}...')
_logger.info(f'Creating experiment {colorama.Fore.CYAN}{exp_id}')
pipe = Pipe(exp_id)
proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
_logger.info('Connecting IPC pipe...')
pipe_file = pipe.connect()
nni.runtime.protocol._in_file = pipe_file
nni.runtime.protocol._out_file = pipe_file
        print('Starting web server...')
        _logger.info('Starting web server...')
_check_rest_server(port)
print('Setting up...')
_init_experiment(config, port, debug) # todo: kill on fail
_logger.info('Setting up...')
_init_experiment(config, port, debug)
return proc, pipe
except Exception as e:
        print('Failed to create experiment')
        _logger.error('Failed to create experiment')
if proc is not None:
with contextlib.suppress(Exception):
proc.kill()
......@@ -58,9 +65,13 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str) -> Popen:
ts = config.training_service.platform
if ts == 'openpai':
ts = 'pai'
args = {
'port': port,
'mode': config._training_service,
'mode': ts,
'experiment_id': experiment_id,
'start_mode': 'new',
'log_level': 'debug' if debug else 'info',
......@@ -77,15 +88,18 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
return Popen(cmd, cwd=node_dir)
def _check_rest_server(port: int, retry: int = 10) -> None:
for _ in range(retry):
def _check_rest_server(port: int, retry: int = 3) -> None:
for i in range(retry):
with contextlib.suppress(Exception):
rest.get(port, '/check-status')
return
if i > 0:
_logger.warning('Timeout, retry...')
time.sleep(1)
rest.get(port, '/check-status')
def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
rest.put(port, '/experiment/cluster-metadata', config.cluster_metadata_json())
rest.post(port, '/experiment', config.experiment_config_json())
for cluster_metadata in convert.to_cluster_metadata(config):
rest.put(port, '/experiment/cluster-metadata', cluster_metadata)
rest.post(port, '/experiment', convert.to_rest_json(config))
......@@ -26,7 +26,6 @@ import subprocess
import re
import json
import requests
import yaml
__all__ = [
'ExternalExperiment',
......@@ -260,38 +259,6 @@ class ExternalExperiment:
self._endpoint = 'http://localhost:{}'.format(self._port)
self._exp_id = self.get_experiment_profile()['id']
def tmp_start_retiarii(self, graph_ir, training_approach,
applied_mutators, strategy, exp_config):
# prepare search space file which includes base graph IR and mutators
search_space = {}
search_space['base_model_ir'] = graph_ir
search_space['applied_mutators'] = applied_mutators
search_space['training_approach'] = training_approach
with open('search_space.json', 'w') as f:
json.dump(search_space, f)
# add advisor config to exp_config
exp_config['searchSpacePath'] = 'search_space.json'
exp_config['useAnnotation'] = False
exp_config['advisor'] = {
'codeDir': '.',
'classFileName': 'advisor_entry.py',
'className': 'RetiariiAdvisor',
'classArgs': {
'strategy': '{}.{}'.format(strategy['filename'], strategy['funcname'])
}
}
# add trial config to exp_config
exp_config['trial'] = {
'command': 'python3 -m nni.retiarii.trial_entry',
'codeDir': '../..',
'gpuNum': 0
}
# dump exp_config to nni.yml
with open('nni.yml', 'w') as f:
yaml.dump(exp_config, f)
# start experiment
self.start_experiment('nni.yml')
def start_experiment(self, config_file, port=None, debug=False):
"""
Start an experiment with specified configuration file and connect to it.
......
......@@ -3,7 +3,7 @@ import os
import sys
if sys.platform == 'win32':
import _win32
import _winapi
import msvcrt
class WindowsPipe:
......@@ -11,27 +11,27 @@ if sys.platform == 'win32':
self.path: str = r'\\.\pipe\nni-' + experiment_id
self.file = None
self._handle = _win32.CreateNamedPipe(
self._handle = _winapi.CreateNamedPipe(
self.path,
_win32.PIPE_ACCESS_DUPLEX,
_win32.PIPE_TYPE_MESSAGE | _win32.PIPE_READMODE_MESSAGE | _win32.PIPE_WAIT,
_winapi.PIPE_ACCESS_DUPLEX,
_winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_WAIT,
1,
8192,
8192,
0,
_win32.NULL
_winapi.NULL
)
def connect(self) -> BufferedIOBase:
_win32.ConnectNamedPipe(self._handle, _win32.NULL)
fd = msvcrt.open_osfhandle(self._handle)
self.file = os.fdopen(fd, 'rwb')
_winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
fd = msvcrt.open_osfhandle(self._handle, 0)
self.file = os.fdopen(fd, 'w+b')
return self.file
def close(self) -> None:
if self.file is not None:
self.file.close()
_win32.CloseHandle(self._handle)
_winapi.CloseHandle(self._handle)
Pipe = WindowsPipe
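# Illustrative lifecycle, mirroring how the launcher drives the pipe
# (experiment id is hypothetical):
#
#   pipe = Pipe('A1b2C3d4')
#   file = pipe.connect()   # blocks until the NNI manager attaches
#   ...                     # file serves as both protocol in and out channel
#   pipe.close()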
......@@ -52,7 +52,7 @@ else:
def connect(self) -> BufferedIOBase:
conn, _ = self._socket.accept()
self.file = conn.makefile('rwb')
self.file = conn.makefile('w+b')
return self.file
def close(self) -> None:
......
......@@ -4,8 +4,11 @@ import logging
from logging import FileHandler, Formatter, Handler, StreamHandler
from pathlib import Path
import sys
import time
from typing import Optional
import colorama
from .env_vars import dispatcher_env_vars, trial_env_vars
......@@ -17,6 +20,8 @@ def init_logger() -> None:
    The detection should work in most cases, with the exception of `nnictl` and `nni.experiment`:
    they will be identified as "standalone" mode and must configure the logger by themselves.
"""
colorama.init()
if dispatcher_env_vars.SDK_PROCESS == 'dispatcher':
_init_logger_dispatcher()
return
......@@ -33,6 +38,15 @@ def init_logger() -> None:
_init_logger_standalone()
def init_logger_experiment() -> None:
"""
Initialize logger for `nni.experiment.Experiment`.
This function will get invoked after `init_logger()`.
"""
formatter.format = _colorful_format
time_format = '%Y-%m-%d %H:%M:%S'
formatter = Formatter(
......@@ -40,14 +54,14 @@ formatter = Formatter(
time_format
)
def _init_logger_dispatcher() -> None:
log_level_map = {
'fatal': logging.CRITICAL,
'error': logging.ERROR,
'warning': logging.WARNING,
'info': logging.INFO,
'debug': logging.DEBUG
'debug': logging.DEBUG,
'trace': 0
}
log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
......@@ -93,6 +107,21 @@ def _setup_logger(name: str, handler: Handler, level: int) -> None:
logger.setLevel(level)
logger.propagate = False
def _colorful_format(record):
if record.levelno >= logging.ERROR:
color = colorama.Fore.RED
elif record.levelno >= logging.WARNING:
color = colorama.Fore.YELLOW
elif record.levelno >= logging.INFO:
color = colorama.Fore.GREEN
else:
color = colorama.Fore.BLUE
msg = color + (record.msg % record.args) + colorama.Style.RESET_ALL
time = formatter.formatTime(record, time_format)
if record.levelno < logging.INFO:
return '[{}] {}:{} {}'.format(time, record.threadName, record.name, msg)
else:
return '[{}] {}'.format(time, msg)
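# Illustrative output of the formatter above (color codes shown symbolically;
# timestamps hypothetical):
#
#   INFO  record -> '[2020-12-01 12:00:00] <GREEN>Setting up...<RESET>'
#   DEBUG record -> '[2020-12-01 12:00:00] MainThread:nni.experiment <BLUE>msg<RESET>'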
class _LogFileWrapper(TextIOBase):
# wrap the logger file so that anything written to it will automatically get formatted
......
......@@ -25,11 +25,11 @@ class MsgDispatcherBase(Recoverable):
"""
def __init__(self):
self.stopping = False
if multi_thread_enabled():
self.pool = ThreadPool()
self.thread_results = []
else:
self.stopping = False
self.default_command_queue = Queue()
self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
......@@ -43,11 +43,11 @@ class MsgDispatcherBase(Recoverable):
"""Run the tuner.
        This function will never return unless an exception is raised.
"""
_logger.info('Start dispatcher')
_logger.info('Dispatcher started')
if dispatcher_env_vars.NNI_MODE == 'resume':
self.load_checkpoint()
while True:
while not self.stopping:
command, data = receive()
if data:
data = json_tricks.loads(data)
......@@ -75,7 +75,7 @@ class MsgDispatcherBase(Recoverable):
self.default_worker.join()
self.assessor_worker.join()
_logger.info('Terminated by NNI manager')
        _logger.info('Dispatcher terminated')
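    # Illustrative graceful-shutdown pattern enabled by the stopping flag above
    # (mirrors Experiment.stop(); since receive() blocks, the caller joins with
    # a timeout rather than waiting indefinitely):
    #
    #   dispatcher.stopping = True
    #   dispatcher_thread.join(timeout=1)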
def command_queue_worker(self, command_queue):
"""Process commands in command queues.
......