"src/vscode:/vscode.git/clone" did not exist on "5fc6afe6d0da0a3f188b6ce6586fb56f30e3d398"
Unverified commit 4784cc6c, authored by liuzhe-lz, committed by GitHub

Merge pull request #3302 from microsoft/v2.0-merge

Merge branch v2.0 into master (no squash)
parents 25db55ca 349ead41
authorName: default
experimentName: example_cifar10-network-morphism
trialConcurrency: 1
maxExecDuration: 24h
maxTrialNum: 10
#choice: local, remote, pai
trainingServicePlatform: paiYarn
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, NetworkMorphism
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: NetworkMorphism
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
    # for now, this tuner only supports cv domain
    task: cv
    #input image width
    input_width: 32
    #input image channel
    input_channel: 3
    #number of classes
    n_output_node: 10
trial:
  command: python3 cifar10_keras.py
  codeDir: .
  gpuNum: 1
  cpuNum: 1
  memoryMB: 8196
  #The docker image to run nni job on pai
  image: msranni/nni:latest
paiYarnConfig:
  #The username to login pai
  userName: username
  #The password to login pai
  passWord: password
  #The host of restful server of pai
  host: 10.10.10.10
\ No newline at end of file
-numpy==1.14.2
+numpy==1.19.3
 tensorflow==1.15.4
 torchvision==0.2.1
 Keras==2.3.1
......
authorName: default
experimentName: example_sklearn
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 100
#choice: local, remote, pai
trainingServicePlatform: paiYarn
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner,MetisTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 main.py
  codeDir: .
  gpuNum: 0
  cpuNum: 1
  memoryMB: 8196
  #The docker image to run nni job on pai
  image: msranni/nni:latest
paiYarnConfig:
  #The username to login pai
  userName: username
  #The password to login pai
  passWord: password
  #The host of restful server of pai
  host: 10.10.10.10
\ No newline at end of file
authorName: default
experimentName: example_sklearn
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 100
#choice: local, remote, pai
trainingServicePlatform: paiYarn
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 main.py
  codeDir: .
  gpuNum: 0
  cpuNum: 1
  memoryMB: 8196
  #The docker image to run nni job on pai
  image: msranni/nni:latest
paiYarnConfig:
  #The username to login pai
  userName: username
  #The password to login pai
  passWord: password
  #The host of restful server of pai
  host: 10.10.10.10
\ No newline at end of file
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-__version__ = '999.0.0-developing'
+try:
+    from .version import __version__
+except ModuleNotFoundError:
+    __version__ = '999.dev0'
 from .runtime.log import init_logger
 init_logger()
......
@@ -41,7 +41,7 @@ class NaiveQuantizer(Quantizer):
         wrapper.module.weight = weight
         return weight
-def update_ema(biased_ema, value, decay, step):
+def update_ema(biased_ema, value, decay):
     """
     calculate biased stat and unbiased stat in each step using exponential moving average method
@@ -53,16 +53,13 @@ def update_ema(biased_ema, value, decay, step):
         current stat value
     decay : float
         the weight of previous stat value, larger means smoother curve
-    step : int
-        current step
     Returns
     -------
-    float, float
+    float
     """
     biased_ema = biased_ema * decay + (1 - decay) * value
-    unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
-    return biased_ema, unbiased_ema
+    return biased_ema
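The hunk above drops the bias-correction term from update_ema, so QAT now tracks only the biased estimate. A minimal self-contained sketch (ours, not part of the patch) of what the removed term did:

# EMA with and without the bias correction removed above (illustration only).
decay = 0.99
biased_ema = 0.0
for step, value in enumerate([1.0, 1.0, 1.0], start=1):
    biased_ema = biased_ema * decay + (1 - decay) * value
    unbiased_ema = biased_ema / (1 - decay ** step)  # the removed correction
    print(step, round(biased_ema, 4), round(unbiased_ema, 4))
# The biased estimate creeps up from ~0.01 toward 1.0, while the corrected
# one equals the true value 1.0 from the very first step.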
 def update_quantization_param(bits, rmin, rmax):
@@ -85,16 +82,10 @@ def update_quantization_param(bits, rmin, rmax):
     # extend the [min, max] interval to ensure that it contains 0.
     # Otherwise, we would not meet the requirement that 0 be an exactly
     # representable value.
-    if rmin.is_cuda:
-        rmin = torch.min(rmin, torch.Tensor([0]).cuda())
-        rmax = torch.max(rmax, torch.Tensor([0]).cuda())
-        qmin = torch.Tensor([0]).cuda()
-        qmax = torch.Tensor([(1 << bits) - 1]).cuda()
-    else:
-        rmin = torch.min(rmin, torch.Tensor([0]))
-        rmax = torch.max(rmax, torch.Tensor([0]))
-        qmin = torch.Tensor([0])
-        qmax = torch.Tensor([(1 << bits) - 1])
+    rmin = torch.min(rmin, torch.Tensor([0]).to(rmin.device))
+    rmax = torch.max(rmax, torch.Tensor([0]).to(rmin.device))
+    qmin = torch.Tensor([0]).to(rmin.device)
+    qmax = torch.Tensor([(1 << bits) - 1]).to(rmin.device)
     # First determine the scale.
     scale = (rmax - rmin) / (qmax - qmin)
@@ -103,7 +94,6 @@ def update_quantization_param(bits, rmin, rmax):
     initial_zero_point = qmin - rmin / scale
     # Now we need to nudge the zero point to be an integer
-    nudged_zero_point = 0
     if initial_zero_point < qmin:
         nudged_zero_point = qmin
     elif initial_zero_point > qmax:
@@ -121,6 +111,15 @@ def get_bits_length(config, quant_type):
     return config["quant_bits"].get(quant_type)
+class QATGrad(QuantGrad):
+    @staticmethod
+    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
+        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
+        mask = (tensor_q < qmin) | (tensor_q > qmax)
+        grad_output[mask] = 0
+        return grad_output
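QATGrad factors the straight-through estimator with range clipping out of the QuantGrad base class. A toy illustration (ours, plain PyTorch, arbitrary values) of the masking it performs:

import torch

x = torch.tensor([-2.0, 0.0, 2.0])
scale, zero_point, qmin, qmax = 0.0157, 64.0, 0.0, 255.0
x_q = zero_point + x / scale           # mirrors what QuantGrad._quantize computes
grad = torch.ones_like(x)              # stand-in for the upstream gradient
grad[(x_q < qmin) | (x_q > qmax)] = 0  # zero the gradient where quantization clipped
print(grad)                            # tensor([0., 1., 1.])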
 class QAT_Quantizer(Quantizer):
     """Quantizer defined in:
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
@@ -148,6 +147,7 @@ class QAT_Quantizer(Quantizer):
             types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
+        self.quant_grad = QATGrad
         modules_to_compress = self.get_modules_to_compress()
         self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
@@ -199,10 +199,8 @@ class QAT_Quantizer(Quantizer):
         -------
         Tensor
         """
-        if real_val.is_cuda:
-            op.zero_point = op.zero_point.cuda()
-            op.scale = op.scale.cuda()
+        op.zero_point = op.zero_point.to(real_val.device)
+        op.scale = op.scale.to(real_val.device)
         transformed_val = op.zero_point + real_val / op.scale
         qmin = 0
         qmax = (1 << bits) - 1
@@ -269,16 +267,17 @@ class QAT_Quantizer(Quantizer):
         assert output_bits >= 1, "quant bits length should be at least 1"
         if quant_start_step > self.bound_model.steps:
+            module.tracked_min_biased, module.tracked_max_biased = torch.min(output), torch.max(output)
             return output
         # we dont update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
-                                                                       module.ema_decay, self.bound_model.steps)
-            module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
-                                                                       module.ema_decay, self.bound_model.steps)
-            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
+            module.tracked_min_biased = update_ema(module.tracked_min_biased, current_min,
+                                                   module.ema_decay)
+            module.tracked_max_biased = update_ema(module.tracked_max_biased, current_max,
+                                                   module.ema_decay)
+            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_biased, module.tracked_max_biased)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
         return out
@@ -342,7 +341,7 @@ class DoReFaQuantizer(Quantizer):
 class ClipGrad(QuantGrad):
     @staticmethod
-    def quant_backward(tensor, grad_output, quant_type):
+    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
         if quant_type == QuantType.QUANT_OUTPUT:
             grad_output[torch.abs(tensor) > 1] = 0
         return grad_output
......
@@ -132,7 +132,7 @@ class DartsDiscreteMutator(Mutator):
     ----------
     model : nn.Module
         The model to apply the mutator.
-    parent_mutator : Mutator
+    parent_mutator : nni.nas.pytorch.mutator.Mutator
         The mutator that provides ``sample_final`` method, that will be called to get the architecture.
     """
     def __init__(self, model, parent_mutator):
......
@@ -20,7 +20,7 @@ class SPOSSupernetTrainer(Trainer):
     ----------
     model : nn.Module
         Model with mutables.
-    mutator : Mutator
+    mutator : nni.nas.pytorch.mutator.Mutator
         A mutator object that has been initialized with the model.
     loss : callable
         Called with logits and targets. Returns a loss tensor.
......
@@ -580,10 +580,15 @@ class QuantType:
     """
     Enum class for quantization type.
     """
-    QUANT_INPUT = 'input'
-    QUANT_WEIGHT = 'weight'
-    QUANT_OUTPUT = 'output'
+    QUANT_INPUT = 0
+    QUANT_WEIGHT = 1
+    QUANT_OUTPUT = 2
+
+QType_Dict = {
+    0: "input",
+    1: "weight",
+    2: "output"
+}
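Our reading of the switch from strings to integers: the forward pass below now stores the quant type in the autograd context via ctx.save_for_backward, which only accepts tensors, and an int wraps into a tensor while a string does not. QType_Dict maps the int back to the string keys used in config['quant_bits']. A small sketch (ours):

import torch

quant_type = 2                       # QuantType.QUANT_OUTPUT
stored = torch.Tensor([quant_type])  # storable in the autograd ctx
restored = int(stored.item())
print({0: "input", 1: "weight", 2: "output"}[restored])  # -> "output"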
 class QuantGrad(torch.autograd.Function):
     """
@@ -628,7 +633,7 @@ class QuantGrad(torch.autograd.Function):
         return config["quant_bits"].get(quant_type)
     @staticmethod
-    def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
+    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
         """
         This method should be overrided by subclass to provide customized backward function,
         default implementation is Straight-Through Estimator
@@ -652,9 +657,6 @@ class QuantGrad(torch.autograd.Function):
         tensor
             gradient of the input of quantization operation
         """
-        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
-        mask = (tensor_q < qmin) | (tensor_q > qmax)
-        grad_output[mask] = 0
         return grad_output
     @staticmethod
@@ -668,15 +670,21 @@ class QuantGrad(torch.autograd.Function):
         else:
             raise ValueError("unrecognized QuantType.")
-        bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
-        qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device)
-        ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
+        bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type])
+        qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device)
+        if hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
+            scale = wrapper.module.scale
+            zero_point = wrapper.module.zero_point
+        else:
+            scale, zero_point = None, None
+        ctx.save_for_backward(tensor, torch.Tensor([quant_type]), scale, zero_point, qmin, qmax)
         return output
     @classmethod
     def backward(cls, ctx, grad_output):
-        tensor, scale, zero_point, qmin, qmax = ctx.saved_variables
-        output = cls.quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax)
+        tensor, quant_type, scale, zero_point, qmin, qmax = ctx.saved_variables
+        output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
         return output, None, None, None
 def _check_weight(module):
......
@@ -273,7 +273,8 @@ infer_from_inshape = {
     'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
     'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
     'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask),
-    'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask)
+    'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
+    'aten::detach': lambda module_masks, mask: dropout_inshape(module_masks, mask)
 }
 """
@@ -308,7 +309,8 @@ infer_from_outshape = {
     'aten::mean': lambda module_masks, mask, shape: mean_outshape(module_masks, mask, shape),
     'Dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask),
     'Dropout2d': lambda module_masks, mask: dropout_outshape(module_masks, mask),
-    'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask)
+    'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask),
+    'aten::detach': lambda module_masks, mask: dropout_outshape(module_masks, mask)
 }
@@ -889,23 +891,18 @@ def conv2d_mask(module_masks, mask):
         sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
         index = torch.nonzero(weight_mask.abs().sum(
             sum_idx) != 0, as_tuple=True)[0]
-        if len(index) == weight_mask.shape[dim]:  # full mask
-            index = None
-        if index is None:
-            return None, None, None
-        else:
-            index = index.long().to(weight_mask.device)
-            weight_cmask = CoarseMask(num_dim=4)
-            weight_cmask.add_index_mask(dim=dim, index=index)
-            bias_cmask = None
-            if dim == 0 and 'bias' in mask and mask['bias'] is not None:
-                bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
-                assert torch.all(torch.eq(index, bias_index)), \
-                    "bias mask should be consistent with weight mask"
-                bias_cmask = CoarseMask(num_dim=1)
-                bias_cmask.add_index_mask(dim=0, index=bias_index)
-            return index, weight_cmask, bias_cmask
+        index = index.long().to(weight_mask.device)
+        weight_cmask = CoarseMask(num_dim=4)
+        weight_cmask.add_index_mask(dim=dim, index=index)
+        bias_cmask = None
+        if dim == 0 and 'bias' in mask and mask['bias'] is not None:
+            bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
+            assert torch.all(torch.eq(index, bias_index)), \
+                "bias mask should be consistent with weight mask"
+            bias_cmask = CoarseMask(num_dim=1)
+            bias_cmask.add_index_mask(dim=0, index=bias_index)
+        return index, weight_cmask, bias_cmask
     index, weight_cmask, bias_cmask = convert_to_coarse_mask(
         mask, dim=conv_prune_dim)
@@ -960,6 +957,7 @@ def conv2d_inshape(module_masks, mask):
         # the same conv layer may be accessed more
         # than once, such as a concat operation.
         # mask conflict should be solved by fix_mask_conflict before speedup
         assert module_masks.input_mask == mask
     # shape changes pass through depths wise conv layers
......
@@ -31,6 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
         # if the input is the path of the mask_file
         assert os.path.exists(masks)
         masks = torch.load(masks)
+    assert len(masks) > 0, 'Mask tensor cannot be empty'
     # if the user uses the model and dummy_input to trace the model, we
     # should get the traced model handly, so that, we only trace the
     # model once, GroupMaskConflict and ChannelMaskConflict will reuse
@@ -127,6 +128,7 @@ class CatMaskPadding(MaskFix):
         for layer in layers:
             if layer in self.masks:
                 continue
             module = name_to_module[layer]
             w_shape = module.weight.data.size()
             w_mask = torch.ones(w_shape).to(device)
@@ -136,6 +138,7 @@ class CatMaskPadding(MaskFix):
                 b_shape = module.bias.data.size()
                 b_mask = torch.ones(b_shape).to(device)
                 self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
         return self.masks
@@ -250,6 +253,10 @@ class ChannelMaskConflict(MaskFix):
             self.model, self.dummy_input, self.traced)
         depen_sets = channel_depen.dependency_sets
         sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
+        (_tmp_name, _tmp_tensor) = list(self.masks.items())[0]
+        device = _tmp_tensor['weight'].device
         for dset in depen_sets:
             if len(dset) <= 1:
                 continue
@@ -301,7 +308,7 @@ class ChannelMaskConflict(MaskFix):
         for i, dim_mask in enumerate(channel_masks):
             if dim_mask is None:
-                channel_masks[i] = torch.ones(num_channels).int()
+                channel_masks[i] = torch.ones(num_channels).int().to(device)
         # merge masks with 'or'
         merged_channel_mask = channel_masks[0].clone()
......
@@ -65,15 +65,19 @@ class ExperimentConfig(ConfigBase):
     tuner: Optional[_AlgorithmConfig] = None
     accessor: Optional[_AlgorithmConfig] = None
     advisor: Optional[_AlgorithmConfig] = None
-    training_service: TrainingServiceConfig
+    training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]]
-    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
+    def __init__(self, training_service_platform: Optional[Union[str, List[str]]] = None, **kwargs):
         kwargs = util.case_insensitive(kwargs)
         if training_service_platform is not None:
             assert 'trainingservice' not in kwargs
-            kwargs['trainingservice'] = util.training_service_config_factory(training_service_platform)
-        elif isinstance(kwargs.get('trainingservice'), dict):
-            kwargs['trainingservice'] = util.training_service_config_factory(**kwargs['trainingservice'])
+            kwargs['trainingservice'] = util.training_service_config_factory(platform = training_service_platform)
+        elif isinstance(kwargs.get('trainingservice'), (dict, list)):
+            # dict means a single training service
+            # list means hybrid training service
+            kwargs['trainingservice'] = util.training_service_config_factory(config = kwargs['trainingservice'])
+        else:
+            raise RuntimeError('Unsupported Training service configuration!')
         super().__init__(**kwargs)
     def validate(self, initialized_tuner: bool = False) -> None:
......
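The constructor above now accepts either a single platform or a list of platforms; a hedged sketch (ours, not from the patch) of both forms, omitting other required fields for brevity:

from nni.experiment.config import ExperimentConfig

# Single training service, as before:
single = ExperimentConfig(training_service_platform='local')

# Hybrid training service, new in this merge - pass a list of platforms:
hybrid = ExperimentConfig(training_service_platform=['local', 'remote'])
# hybrid.training_service is then a list of TrainingServiceConfig objects,
# which the conversion code below maps onto the v1 'hybrid' platform.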
@@ -18,8 +18,29 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
     data = config.json()
     ts = data.pop('trainingService')
-    if ts['platform'] == 'openpai':
-        ts['platform'] = 'pai'
+    data['trial'] = {
+        'command': data.pop('trialCommand'),
+        'codeDir': data.pop('trialCodeDirectory'),
+    }
+    if 'trialGpuNumber' in data:
+        data['trial']['gpuNum'] = data.pop('trialGpuNumber')
+    if isinstance(ts, list):
+        hybrid_names = []
+        for conf in ts:
+            if conf['platform'] == 'openpai':
+                conf['platform'] = 'pai'
+            hybrid_names.append(conf['platform'])
+            _handle_training_service(conf, data)
+        data['trainingServicePlatform'] = 'hybrid'
+        data['hybridConfig'] = {'trainingServicePlatforms': hybrid_names}
+    else:
+        if ts['platform'] == 'openpai':
+            ts['platform'] = 'pai'
+        data['trainingServicePlatform'] = ts['platform']
+        _handle_training_service(ts, data)
     data['authorName'] = 'N/A'
     data['experimentName'] = data.get('experimentName', 'N/A')
@@ -27,7 +48,7 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
     if data['debug']:
         data['versionCheck'] = False
     data['maxTrialNum'] = data.pop('maxTrialNumber', 99999)
-    data['trainingServicePlatform'] = ts['platform']
     ss = data.pop('searchSpace', None)
     ss_file = data.pop('searchSpaceFile', None)
     if ss is not None:
@@ -58,14 +79,9 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
     if tuner_gpu_indices is not None:
         data['tuner']['gpuIndicies'] = tuner_gpu_indices
-    data['trial'] = {
-        'command': data.pop('trialCommand'),
-        'codeDir': data.pop('trialCodeDirectory'),
-    }
-    if 'trialGpuNumber' in data:
-        data['trial']['gpuNum'] = data.pop('trialGpuNumber')
+    return data
+def _handle_training_service(ts, data):
     if ts['platform'] == 'local':
         data['localConfig'] = {
             'useActiveGpu': ts.get('useActiveGpu', False),
@@ -98,6 +114,9 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
         data['trial']['image'] = ts['dockerImage']
         data['trial']['nniManagerNFSMountPath'] = ts['localStorageMountPoint']
         data['trial']['containerNFSMountPath'] = ts['containerStorageMountPoint']
+        data['trial']['paiStorageConfigName'] = ts['storageConfigName']
+        data['trial']['cpuNum'] = ts['trialCpuNumber']
+        data['trial']['memoryMB'] = ts['trialMemorySize']
         data['paiConfig'] = {
             'userName': ts['username'],
             'token': ts['token'],
@@ -140,8 +159,6 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
     elif ts['platform'] == 'adl':
         data['trial']['image'] = ts['dockerImage']
-    return data
 def _convert_gpu_indices(indices):
     return ','.join(str(idx) for idx in indices) if indices is not None else None
@@ -175,19 +192,34 @@ def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
     experiment_config = to_v1_yaml(config, skip_nnictl=True)
     ret = []
-    if config.training_service.platform == 'local':
+    if isinstance(config.training_service, list):
+        hybrid_conf = dict()
+        hybrid_conf['hybrid_config'] = experiment_config['hybridConfig']
+        for conf in config.training_service:
+            metadata = _get_cluster_metadata(conf.platform, experiment_config)
+            if metadata is not None:
+                hybrid_conf.update(metadata)
+        ret.append(hybrid_conf)
+    else:
+        metadata = _get_cluster_metadata(config.training_service.platform, experiment_config)
+        if metadata is not None:
+            ret.append(metadata)
+    if experiment_config.get('nniManagerIp') is not None:
+        ret.append({'nni_manager_ip': {'nniManagerIp': experiment_config['nniManagerIp']}})
+    ret.append({'trial_config': experiment_config['trial']})
+    return ret
+def _get_cluster_metadata(platform: str, experiment_config) -> Dict:
+    if platform == 'local':
         request_data = dict()
         request_data['local_config'] = experiment_config['localConfig']
         if request_data['local_config']:
             if request_data['local_config'].get('gpuIndices') and isinstance(request_data['local_config'].get('gpuIndices'), int):
                 request_data['local_config']['gpuIndices'] = str(request_data['local_config'].get('gpuIndices'))
-            if request_data['local_config'].get('maxTrialNumOnEachGpu'):
-                request_data['local_config']['maxTrialNumOnEachGpu'] = request_data['local_config'].get('maxTrialNumOnEachGpu')
-            if request_data['local_config'].get('useActiveGpu'):
-                request_data['local_config']['useActiveGpu'] = request_data['local_config'].get('useActiveGpu')
-        ret.append(request_data)
+        return request_data
-    elif config.training_service.platform == 'remote':
+    elif platform == 'remote':
         request_data = dict()
         if experiment_config.get('remoteConfig'):
             request_data['remote_config'] = experiment_config['remoteConfig']
@@ -198,31 +230,25 @@ def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
         for i in range(len(request_data['machine_list'])):
             if isinstance(request_data['machine_list'][i].get('gpuIndices'), int):
                 request_data['machine_list'][i]['gpuIndices'] = str(request_data['machine_list'][i].get('gpuIndices'))
-        ret.append(request_data)
+        return request_data
-    elif config.training_service.platform == 'openpai':
-        ret.append({'pai_config': experiment_config['paiConfig']})
-    elif config.training_service.platform == 'aml':
-        ret.append({'aml_config': experiment_config['amlConfig']})
-    elif config.training_service.platform == 'kubeflow':
-        ret.append({'kubeflow_config': experiment_config['kubeflowConfig']})
-    elif config.training_service.platform == 'frameworkcontroller':
-        ret.append({'frameworkcontroller_config': experiment_config['frameworkcontrollerConfig']})
-    elif config.training_service.platform == 'adl':
-        pass
+    elif platform == 'openpai':
+        return {'pai_config': experiment_config['paiConfig']}
+    elif platform == 'aml':
+        return {'aml_config': experiment_config['amlConfig']}
+    elif platform == 'kubeflow':
+        return {'kubeflow_config': experiment_config['kubeflowConfig']}
+    elif platform == 'frameworkcontroller':
+        return {'frameworkcontroller_config': experiment_config['frameworkcontrollerConfig']}
+    elif platform == 'adl':
+        return None
     else:
-        raise RuntimeError('Unsupported training service ' + config.training_service.platform)
+        raise RuntimeError('Unsupported training service ' + platform)
-    if experiment_config.get('nniManagerIp') is not None:
-        ret.append({'nni_manager_ip': {'nniManagerIp': experiment_config['nniManagerIp']}})
-    ret.append({'trial_config': experiment_config['trial']})
-    return ret
 def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]:
     experiment_config = to_v1_yaml(config, skip_nnictl=True)
......
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 from dataclasses import dataclass
-from pathlib import Path
+from pathlib import Path, PurePosixPath
 from typing import Any, Dict, Optional
 from .base import PathLike
@@ -17,6 +17,9 @@ class OpenpaiConfig(TrainingServiceConfig):
     host: str
     username: str
     token: str
+    trial_cpu_number: int
+    trial_memory_size: str
+    storage_config_name: str
     docker_image: str = 'msranni/nni:latest'
     local_storage_mount_point: PathLike
     container_storage_mount_point: str
@@ -34,7 +37,7 @@ class OpenpaiConfig(TrainingServiceConfig):
     _validation_rules = {
         'platform': lambda value: (value == 'openpai', 'cannot be modified'),
         'local_storage_mount_point': lambda value: Path(value).is_dir(),
-        'container_storage_mount_point': lambda value: (Path(value).is_absolute(), 'is not absolute'),
+        'container_storage_mount_point': lambda value: (PurePosixPath(value).is_absolute(), 'is not absolute'),
         'openpai_config_file': lambda value: Path(value).is_file()
     }
......
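Our reading of the PurePosixPath change: container_storage_mount_point names a path inside a Linux container, so it should be validated with POSIX rules even when the client runs on Windows, where a drive-less path is not considered absolute. A two-line illustration (ours):

from pathlib import PureWindowsPath, PurePosixPath

print(PureWindowsPath('/data/nni').is_absolute())  # False - no drive letter
print(PurePosixPath('/data/nni').is_absolute())    # True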
@@ -4,6 +4,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Union
+import warnings
 from .base import ConfigBase, PathLike
 from .common import TrainingServiceConfig
@@ -17,7 +18,7 @@ class RemoteMachineConfig(ConfigBase):
     port: int = 22
     user: str
     password: Optional[str] = None
-    ssh_key_file: Optional[PathLike] = None
+    ssh_key_file: PathLike = None  #'~/.ssh/id_rsa'
     ssh_passphrase: Optional[str] = None
     use_active_gpu: bool = False
     max_trial_number_per_gpu: int = 1
@@ -39,6 +40,8 @@ class RemoteMachineConfig(ConfigBase):
         super().validate()
         if self.password is None and not Path(self.ssh_key_file).is_file():
             raise ValueError(f'Password is not provided and cannot find SSH key file "{self.ssh_key_file}"')
+        if self.password:
+            warnings.warn('Password will be exposed through web UI in plain text. We recommend to use SSH key file.')
 @dataclass(init=False)
 class RemoteConfig(TrainingServiceConfig):
@@ -51,6 +54,10 @@ class RemoteConfig(TrainingServiceConfig):
         kwargs['machinelist'] = util.load_config(RemoteMachineConfig, kwargs.get('machinelist'))
         super().__init__(**kwargs)
+    _canonical_rules = {
+        'machine_list': lambda value: [config.canonical() for config in value]
+    }
     _validation_rules = {
         'platform': lambda value: (value == 'remote', 'cannot be modified')
     }
@@ -8,7 +8,7 @@ Miscellaneous utility functions.
 import math
 import os.path
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, List
 PathLike = Union[Path, str]
@@ -29,12 +29,26 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]:
 def count(*values) -> int:
     return sum(value is not None and value is not False for value in values)
-def training_service_config_factory(platform: str, **kwargs):  # -> TrainingServiceConfig
+def training_service_config_factory(platform: Union[str, List[str]] = None, config: Union[List, Dict] = None):  # -> TrainingServiceConfig
     from .common import TrainingServiceConfig
-    for cls in TrainingServiceConfig.__subclasses__():
-        if cls.platform == platform:
-            return cls(**kwargs)
-    raise ValueError(f'Unrecognized platform {platform}')
+    ts_configs = []
+    if platform is not None:
+        assert config is None
+        platforms = platform if isinstance(platform, list) else [platform]
+        for cls in TrainingServiceConfig.__subclasses__():
+            if cls.platform in platforms:
+                ts_configs.append(cls())
+        if len(ts_configs) < len(platforms):
+            raise RuntimeError('There is unrecognized platform!')
+    else:
+        assert config is not None
+        supported_platforms = {cls.platform: cls for cls in TrainingServiceConfig.__subclasses__()}
+        configs = config if isinstance(config, list) else [config]
+        for conf in configs:
+            if conf['platform'] not in supported_platforms:
+                raise RuntimeError(f'Unrecognized platform {conf["platform"]}')
+            ts_configs.append(supported_platforms[conf['platform']](**conf))
+    return ts_configs if len(ts_configs) > 1 else ts_configs[0]
 def load_config(Type, value):
     if isinstance(value, list):
......
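A usage sketch (ours; the module path nni.experiment.config.util is inferred from the relative imports in this diff) of the factory's two mutually exclusive arguments:

from nni.experiment.config import util

# platform form: one name or a list of names, configs built with defaults
local_ts = util.training_service_config_factory(platform='local')

# config form: dicts as parsed from a YAML/JSON experiment file
same_ts = util.training_service_config_factory(config={'platform': 'local'})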
 import atexit
 import logging
+from pathlib import Path
 import socket
 from subprocess import Popen
 from threading import Thread
 import time
-from typing import Optional, overload
+from typing import Optional, Union, List, overload
 import colorama
 import psutil
@@ -15,8 +16,10 @@ from nni.tuner import Tuner
 from .config import ExperimentConfig
 from . import launcher
+from . import management
 from .pipe import Pipe
 from . import rest
+from ..tools.nnictl.command_utils import kill_command
 nni.runtime.log.init_logger_experiment()
 _logger = logging.getLogger('nni.experiment')
@@ -51,7 +54,7 @@ class Experiment:
         ...
     @overload
-    def __init__(self, tuner: Tuner, training_service: str) -> None:
+    def __init__(self, tuner: Tuner, training_service: Union[str, List[str]]) -> None:
         """
         Prepare an experiment, leaving configuration fields to be set later.
@@ -69,12 +72,13 @@ class Experiment:
             A tuner instance.
         training_service
             Name of training service.
-            Supported value: "local", "remote", "openpai".
+            Supported value: "local", "remote", "openpai", "aml", "kubeflow", "frameworkcontroller", "adl" and hybrid training service.
         """
         ...
     def __init__(self, tuner: Tuner, config=None, training_service=None):
         self.config: ExperimentConfig
+        self.id: Optional[str] = None
         self.port: Optional[int] = None
         self.tuner: Tuner = tuner
         self._proc: Optional[Popen] = None
@@ -82,7 +86,7 @@ class Experiment:
         self._dispatcher: Optional[MsgDispatcher] = None
         self._dispatcher_thread: Optional[Thread] = None
-        if isinstance(config, str):
+        if isinstance(config, (str, list)):
             config, training_service = None, config
         if config is None:
@@ -107,10 +111,15 @@ class Experiment:
         """
         atexit.register(self.stop)
-        if debug:
-            logging.getLogger('nni').setLevel(logging.DEBUG)
-        self._proc, self._pipe = launcher.start_experiment(self.config, port, debug)
+        self.id = management.generate_experiment_id()
+        if self.config.experiment_working_directory is not None:
+            log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
+        else:
+            log_dir = Path.home() / f'nni-experiments/{self.id}/log'
+        nni.runtime.log.start_experiment_log(self.id, log_dir, debug)
+        self._proc, self._pipe = launcher.start_experiment(self.id, self.config, port, debug)
         assert self._proc is not None
         assert self._pipe is not None
@@ -118,7 +127,7 @@ class Experiment:
         # dispatcher must be launched after pipe initialized
         # the logic to launch dispatcher in background should be refactored into dispatcher api
-        self._dispatcher = MsgDispatcher(self.tuner, None)
+        self._dispatcher = self._create_dispatcher()
         self._dispatcher_thread = Thread(target=self._dispatcher.run)
         self._dispatcher_thread.start()
@@ -128,32 +137,37 @@ class Experiment:
             if interface.family == socket.AF_INET:
                 ips.append(interface.address)
         ips = [f'http://{ip}:{port}' for ip in ips if ip]
-        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips)
+        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
         _logger.info(msg)
-        # TODO: register experiment management metadata
+    def _create_dispatcher(self):  # overrided by retiarii, temporary solution
+        return MsgDispatcher(self.tuner, None)
     def stop(self) -> None:
         """
         Stop background experiment.
         """
-        _logger.info('Stopping experiment...')
+        _logger.info('Stopping experiment, please wait...')
         atexit.unregister(self.stop)
+        if self.id is not None:
+            nni.runtime.log.stop_experiment_log(self.id)
         if self._proc is not None:
-            self._proc.kill()
+            kill_command(self._proc.pid)
         if self._pipe is not None:
             self._pipe.close()
         if self._dispatcher_thread is not None:
             self._dispatcher.stopping = True
             self._dispatcher_thread.join(timeout=1)
+        self.id = None
         self.port = None
         self._proc = None
        self._pipe = None
         self._dispatcher = None
         self._dispatcher_thread = None
+        _logger.info('Experiment stopped')
     def run(self, port: int = 8080, debug: bool = False) -> bool:
@@ -169,10 +183,12 @@ class Experiment:
             while True:
                 time.sleep(10)
                 status = self.get_status()
-                if status == 'STOPPED':
+                if status == 'DONE' or status == 'STOPPED':
                     return True
                 if status == 'ERROR':
                     return False
+        except KeyboardInterrupt:
+            _logger.warning('KeyboardInterrupt detected')
         finally:
             self.stop()
......
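Putting the pieces together, a hedged usage sketch (ours) of the Python experiment API after this merge; the tuner import path and the config attribute names are assumptions based on the v2 layout, not shown in this diff:

from nni.experiment import Experiment
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner  # assumed path

tuner = HyperoptTuner('tpe')
experiment = Experiment(tuner, ['local', 'remote'])  # list -> hybrid training service
experiment.config.trial_command = 'python3 trial.py'
experiment.config.max_trial_number = 10
experiment.run(port=8080)  # now also returns True when the status is DONE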
@@ -14,33 +14,37 @@ import nni_node
 from .config import ExperimentConfig
 from .config import convert
+from . import management
 from .pipe import Pipe
 from . import rest
+from ..tools.nnictl.config_utils import Experiments
 _logger = logging.getLogger('nni.experiment')
-def start_experiment(config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
+def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
     pipe = None
     proc = None
     config.validate(initialized_tuner=True)
     _ensure_port_idle(port)
-    if config.training_service.platform == 'openpai':
-        _ensure_port_idle(port + 1, 'OpenPAI requires an additional port')
-    exp_id = management.generate_experiment_id()
+    if isinstance(config.training_service, list):  # hybrid training service
+        _ensure_port_idle(port + 1, 'Hybrid training service requires an additional port')
+    elif config.training_service.platform in ['remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']:
+        _ensure_port_idle(port + 1, f'{config.training_service.platform} requires an additional port')
     try:
-        _logger.info('Creating experiment %s%s', colorama.Fore.CYAN, exp_id)
+        _logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
         pipe = Pipe(exp_id)
-        proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
+        start_time, proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
         _logger.info('Connecting IPC pipe...')
         pipe_file = pipe.connect()
         nni.runtime.protocol._in_file = pipe_file
         nni.runtime.protocol._out_file = pipe_file
         _logger.info('Statring web server...')
         _check_rest_server(port)
+        platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
+        _save_experiment_information(exp_id, port, start_time, platform,
+                                     config.experiment_name, proc.pid, config.experiment_working_directory)
         _logger.info('Setting up...')
         _init_experiment(config, port, debug)
         return proc, pipe
@@ -64,10 +68,13 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
         raise RuntimeError(f'Port {port} is not idle {message}')
-def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str) -> Popen:
-    ts = config.training_service.platform
-    if ts == 'openpai':
-        ts = 'pai'
+def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str) -> Tuple[int, Popen]:
+    if isinstance(config.training_service, list):
+        ts = 'hybrid'
+    else:
+        ts = config.training_service.platform
+        if ts == 'openpai':
+            ts = 'pai'
     args = {
         'port': port,
@@ -85,7 +92,13 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
     for arg_key, arg_value in args.items():
         cmd.append('--' + arg_key)
         cmd.append(str(arg_value))
-    return Popen(cmd, cwd=node_dir)
+    if sys.platform == 'win32':
+        from subprocess import CREATE_NEW_PROCESS_GROUP
+        proc = Popen(cmd, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
+    else:
+        proc = Popen(cmd, cwd=node_dir)
+    return int(time.time() * 1000), proc
 def _check_rest_server(port: int, retry: int = 3) -> None:
@@ -103,3 +116,8 @@ def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
     for cluster_metadata in convert.to_cluster_metadata(config):
         rest.put(port, '/experiment/cluster-metadata', cluster_metadata)
     rest.post(port, '/experiment', convert.to_rest_json(config))
+def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None:
+    experiment_config = Experiments()
+    experiment_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
@@ -31,7 +31,6 @@ if sys.platform == 'win32':
         def close(self) -> None:
             if self.file is not None:
                 self.file.close()
-            _winapi.CloseHandle(self._handle)
     Pipe = WindowsPipe
......
@@ -110,7 +110,7 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
-        mutable : LayerChoice
+        mutable : nni.nas.pytorch.mutables.LayerChoice
             Module whose forward is called.
         args : list of torch.Tensor
             The arguments of its forward function.
@@ -130,7 +130,7 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
-        mutable : InputChoice
+        mutable : nni.nas.pytorch.mutables.InputChoice
            Mutable that is called.
         tensor_list : list of torch.Tensor
             The arguments mutable is called with.
......