Unverified commit 063d6b74 authored by SparkSnail, committed by GitHub

Merge pull request #3580 from microsoft/v2.2

[do not Squash!] Merge V2.2 back to master
parents 08986c6b e1295888
......@@ -3,3 +3,8 @@ checkpoints
runs
nni_auto_gen_search_space.json
checkpoint.json
_generated_model.py
_generated_model_*.py
_generated_model
generated
lightning_logs
......@@ -14,7 +14,7 @@ from nni.algorithms.nas.pytorch.pdarts import PdartsTrainer
# prevent it to be reordered.
if True:
sys.path.append('../darts')
sys.path.append('../../oneshot/darts')
from utils import accuracy
from model import CNN
import datasets
......
authorName: default
experimentName: example_mnist_pytorch
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName: TPE
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
trial:
command: python3 mnist_tensorboard.py
codeDir: .
gpuNum: 0
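The config above points trials at search_space.json; a minimal sketch of that file, covering the hyperparameters the trial script below accepts (the exact ranges here are illustrative, not part of this commit):

{
    "batch_size": {"_type": "choice", "_value": [16, 32, 64, 128]},
    "hidden_size": {"_type": "choice", "_value": [128, 256, 512, 1024]},
    "lr": {"_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1]},
    "momentum": {"_type": "uniform", "_value": [0, 1]}
}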
"""
A deep MNIST classifier using convolutional layers.
This file is a modification of the official PyTorch MNIST example:
https://github.com/pytorch/examples/blob/master/mnist/main.py
"""
import os
import argparse
import logging
import nni
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from nni.utils import merge_parameter
from torchvision import datasets, transforms
logger = logging.getLogger('mnist_AutoML')
writer = SummaryWriter(log_dir=os.path.join(os.environ['NNI_OUTPUT_DIR'], 'tensorboard'))
class Net(nn.Module):
def __init__(self, hidden_size):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, hidden_size)
self.fc2 = nn.Linear(hidden_size, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if (args['batch_num'] is not None) and batch_idx >= args['batch_num']:
break
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
writer.add_scalar('Loss/train', loss, epoch)
loss.backward()
optimizer.step()
if batch_idx % args['log_interval'] == 0:
logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset), accuracy))
return accuracy
def main(args):
use_cuda = not args['no_cuda'] and torch.cuda.is_available()
torch.manual_seed(args['seed'])
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
data_dir = args['data_dir']
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args['batch_size'], shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=1000, shuffle=True, **kwargs)
hidden_size = args['hidden_size']
model = Net(hidden_size=hidden_size).to(device)
optimizer = optim.SGD(model.parameters(), lr=args['lr'],
momentum=args['momentum'])
for epoch in range(1, args['epochs'] + 1):
train(args, model, device, train_loader, optimizer, epoch)
test_acc = test(args, model, device, test_loader)
writer.add_scalar('Accuracy/test', test_acc, epoch)
# report intermediate result
nni.report_intermediate_result(test_acc)
logger.debug('test accuracy %g', test_acc)
logger.debug('Pipe send intermediate result done.')
writer.close()
# report final result
nni.report_final_result(test_acc)
logger.debug('Final result is %g', test_acc)
logger.debug('Send final result done.')
def get_params():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument("--data_dir", type=str,
default='./data', help="data directory")
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument("--batch_num", type=int, default=None)
parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
help='hidden layer size (default: 512)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
help='how many batches to wait before logging training status')
args, _ = parser.parse_known_args()
return args
if __name__ == '__main__':
try:
# get parameters from tuner
tuner_params = nni.get_next_parameter()
logger.debug(tuner_params)
params = vars(merge_parameter(get_params(), tuner_params))
print(params)
main(params)
except Exception as exception:
logger.exception(exception)
raise
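For reference, merge_parameter overlays the tuner's sampled values on the argparse defaults. A hedged sketch of what one trial might see (the values are one possible sample from the search space, not recorded output):

# Hypothetical parameters returned by nni.get_next_parameter() for one trial:
tuner_params = {'batch_size': 32, 'hidden_size': 1024, 'lr': 0.01, 'momentum': 0.5}
# merge_parameter(get_params(), tuner_params) overlays these on the argparse
# defaults, so the tuned fields change while epochs/seed/data_dir keep their
# defaults; vars(...) then turns the namespace into the dict that main() indexes.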
......@@ -32,6 +32,13 @@ def main():
if exp_params.get('deprecated', {}).get('multiThread'):
enable_multi_thread()
if 'trainingServicePlatform' in exp_params: # config schema is v1
from types import SimpleNamespace
from .experiment.config.convert import convert_algo
for algo_type in ['tuner', 'assessor', 'advisor']:
if algo_type in exp_params:
exp_params[algo_type] = convert_algo(algo_type, exp_params, SimpleNamespace()).json()
if exp_params.get('advisor') is not None:
# advisor is enabled and starts to run
_run_advisor(exp_params)
......
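The convert_algo call above normalizes a v1-schema algorithm section into the v2 JSON shape before the advisor or tuner is dispatched. A hedged illustration (the camelCase output shape assumes ConfigBase.json()'s conventions):

# v1-style experiment params (schema identified by trainingServicePlatform):
exp_params = {'trainingServicePlatform': 'local',
              'tuner': {'builtinTunerName': 'TPE',
                        'classArgs': {'optimize_mode': 'maximize'}}}
# after the loop, exp_params['tuner'] is roughly:
#   {'name': 'TPE', 'classArgs': {'optimize_mode': 'maximize'}}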
......@@ -7,6 +7,7 @@ from prettytable import PrettyTable
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence
from nni.compression.pytorch.compressor import PrunerModuleWrapper
......@@ -32,21 +33,27 @@ class ModelProfiler:
for reference, please see ``self.ops``.
mode:
the mode of how to collect information. If the mode is set to `default`,
only the information of convolution and linear will be collected.
only the information of convolution, linear and rnn modules will be collected.
If the mode is set to `full`, other operations will also be collected.
"""
self.ops = {
nn.Conv1d: self._count_convNd,
nn.Conv2d: self._count_convNd,
nn.Conv3d: self._count_convNd,
nn.Linear: self._count_linear
nn.ConvTranspose1d: self._count_convNd,
nn.ConvTranspose2d: self._count_convNd,
nn.ConvTranspose3d: self._count_convNd,
nn.Linear: self._count_linear,
nn.RNNCell: self._count_rnn_cell,
nn.GRUCell: self._count_gru_cell,
nn.LSTMCell: self._count_lstm_cell,
nn.RNN: self._count_rnn,
nn.GRU: self._count_gru,
nn.LSTM: self._count_lstm
}
self._count_bias = False
if mode == 'full':
self.ops.update({
nn.ConvTranspose1d: self._count_convNd,
nn.ConvTranspose2d: self._count_convNd,
nn.ConvTranspose3d: self._count_convNd,
nn.BatchNorm1d: self._count_bn,
nn.BatchNorm2d: self._count_bn,
nn.BatchNorm3d: self._count_bn,
......@@ -86,7 +93,7 @@ class ModelProfiler:
def _count_convNd(self, m, x, y):
cin = m.in_channels
kernel_ops = m.weight.size()[2] * m.weight.size()[3]
kernel_ops = torch.zeros(m.weight.size()[2:]).numel()
output_size = torch.zeros(y.size()[2:]).numel()
cout = y.size()[1]
......@@ -156,13 +163,125 @@ class ModelProfiler:
return self._get_result(m, total_ops)
def _count_cell_flops(self, input_size, hidden_size, cell_type):
# h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
total_ops = hidden_size * (input_size + hidden_size) + hidden_size
if self._count_bias:
total_ops += hidden_size * 2
if cell_type == 'rnn':
return total_ops
if cell_type == 'gru':
# r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
# z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
# n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
total_ops *= 3
# r hadamard : r * (~)
total_ops += hidden_size
# h' = (1 - z) * n + z * h
# hadamard hadamard add
total_ops += hidden_size * 3
elif cell_type == 'lstm':
# i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
# f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
# o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
# g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
total_ops *= 4
# c' = f * c + i * g
# hadamard hadamard add
total_ops += hidden_size * 3
# h' = o * \tanh(c')
total_ops += hidden_size
return total_ops
def _count_rnn_cell(self, m, x, y):
total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'rnn')
batch_size = x[0].size(0)
total_ops *= batch_size
return self._get_result(m, total_ops)
def _count_gru_cell(self, m, x, y):
total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'gru')
batch_size = x[0].size(0)
total_ops *= batch_size
return self._get_result(m, total_ops)
def _count_lstm_cell(self, m, x, y):
total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'lstm')
batch_size = x[0].size(0)
total_ops *= batch_size
return self._get_result(m, total_ops)
def _get_bsize_nsteps(self, m, x):
if isinstance(x[0], PackedSequence):
batch_size = torch.max(x[0].batch_sizes)
num_steps = x[0].batch_sizes.size(0)
else:
if m.batch_first:
batch_size = x[0].size(0)
num_steps = x[0].size(1)
else:
batch_size = x[0].size(1)
num_steps = x[0].size(0)
return batch_size, num_steps
def _count_rnn_module(self, m, x, y, module_name):
input_size = m.input_size
hidden_size = m.hidden_size
num_layers = m.num_layers
batch_size, num_steps = self._get_bsize_nsteps(m, x)
total_ops = self._count_cell_flops(input_size, hidden_size, module_name)
for _ in range(num_layers - 1):
if m.bidirectional:
cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size, module_name) * 2
else:
cell_flops = self._count_cell_flops(hidden_size, hidden_size, module_name)
total_ops += cell_flops
total_ops *= num_steps
total_ops *= batch_size
return total_ops
def _count_rnn(self, m, x, y):
total_ops = self._count_rnn_module(m, x, y, 'rnn')
return self._get_result(m, total_ops)
def _count_gru(self, m, x, y):
total_ops = self._count_rnn_module(m, x, y, 'gru')
return self._get_result(m, total_ops)
def _count_lstm(self, m, x, y):
total_ops = self._count_rnn_module(m, x, y, 'lstm')
return self._get_result(m, total_ops)
def count_module(self, m, x, y, name):
# assume x is tuple of single tensor
result = self.ops[type(m)](m, x, y)
output_size = y[0].size() if isinstance(y, tuple) else y.size()
total_result = {
'name': name,
'input_size': tuple(x[0].size()),
'output_size': tuple(y.size()),
'output_size': tuple(output_size),
'module_type': type(m).__name__,
**result
}
......@@ -279,10 +398,6 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
model(*x)
# restore origin status
for name, m in model.named_modules():
if hasattr(m, 'weight_mask'):
delattr(m, 'weight_mask')
model.train(training).to(original_device)
for handler in handler_collection:
handler.remove()
......
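As a sanity check of _count_cell_flops, here is the LSTM arithmetic spelled out for input_size=128, hidden_size=256 with bias counting disabled; the numbers are a worked example, but the formula mirrors the code above:

input_size, hidden_size = 128, 256
base = hidden_size * (input_size + hidden_size) + hidden_size  # one gate: W_ih x + W_hh h (+ activation)
ops = base * 4            # four gates: i, f, o, g
ops += hidden_size * 3    # c' = f * c + i * g  -> two hadamards + one add
ops += hidden_size        # h' = o * tanh(c')
assert ops == 395264      # per time step and per sample, before batch scaling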
......@@ -15,6 +15,7 @@ class AmlConfig(TrainingServiceConfig):
workspace_name: str
compute_target: str
docker_image: str = 'msranni/nni:latest'
max_trial_number_per_gpu: int = 1
_validation_rules = {
'platform': lambda value: (value == 'aml', 'cannot be modified')
......
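With the new field, an AML section of a v2 YAML config could look like the sketch below. The placeholder values and the subscriptionId key are assumptions (only workspace_name, compute_target, docker_image, and max_trial_number_per_gpu appear in this hunk), and maxTrialNumberPerGpu assumes the camelCase form produced by ConfigBase.json():

trainingService:
  platform: aml
  subscriptionId: <subscription-id>
  resourceGroup: <resource-group>
  workspaceName: <workspace-name>
  computeTarget: <compute-target>
  dockerImage: msranni/nni:latest
  maxTrialNumberPerGpu: 2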
......@@ -82,6 +82,7 @@ class ConfigBase:
Convert config to JSON object.
The keys of returned object will be camelCase.
"""
self.validate()
return dataclasses.asdict(
self.canonical(),
dict_factory=lambda items: dict((util.camel_case(k), v) for k, v in items if v is not None)
......
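The dict_factory above both camel-cases keys and drops None-valued fields; a small sketch of the intended behavior (util.camel_case's exact implementation is assumed):

# items as dataclasses.asdict would hand to dict_factory for one config:
items = [('max_trial_number_per_gpu', 2), ('docker_image', None)]
# the factory drops the None field and camel-cases the key, yielding:
#   {'maxTrialNumberPerGpu': 2}
# note that json() now also calls self.validate() first, so an invalid
# config fails fast instead of being POSTed to the REST server.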
......@@ -98,6 +98,13 @@ class ExperimentConfig(ConfigBase):
if isinstance(kwargs.get(algo_type), dict):
setattr(self, algo_type, _AlgorithmConfig(**kwargs.pop(algo_type)))
def canonical(self):
ret = super().canonical()
if isinstance(ret.training_service, list):
for i, ts in enumerate(ret.training_service):
ret.training_service[i] = ts.canonical()
return ret
def validate(self, initialized_tuner: bool = False) -> None:
super().validate()
if initialized_tuner:
......
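The canonical() override exists for hybrid mode; a comments-only sketch of the case it handles (class names are illustrative):

# In hybrid mode, training_service holds several service configs at once, e.g.
#   config.training_service = [LocalConfig(...), RemoteConfig(...)]
# super().canonical() does not recurse into list elements, so each element
# is canonicalized explicitly here before serialization.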
......@@ -45,31 +45,8 @@ def to_v2(v1) -> ExperimentConfig:
_move_field(v1_trial, v2, 'gpuNum', 'trial_gpu_number')
for algo_type in ['tuner', 'assessor', 'advisor']:
if algo_type not in v1:
continue
v1_algo = v1.pop(algo_type)
builtin_name = v1_algo.pop(f'builtin{algo_type.title()}Name', None)
class_args = v1_algo.pop('classArgs', None)
if builtin_name is not None:
v2_algo = AlgorithmConfig(name=builtin_name, class_args=class_args)
else:
class_directory = util.canonical_path(v1_algo.pop('codeDir'))
class_file_name = v1_algo.pop('classFileName')
assert class_file_name.endswith('.py')
class_name = class_file_name[:-3] + '.' + v1_algo.pop('className')
v2_algo = CustomAlgorithmConfig(
class_name=class_name,
class_directory=class_directory,
class_args=class_args
)
setattr(v2, algo_type, v2_algo)
_deprecate(v1_algo, v2, 'includeIntermediateResults')
_move_field(v1_algo, v2, 'gpuIndices', 'tuner_gpu_indices')
assert not v1_algo, v1_algo
if algo_type in v1:
convert_algo(algo_type, v1, v2)
ts = v2.training_service
......@@ -134,7 +111,7 @@ def to_v2(v1) -> ExperimentConfig:
_move_field(aml_config, ts, 'resourceGroup', 'resource_group')
_move_field(aml_config, ts, 'workspaceName', 'workspace_name')
_move_field(aml_config, ts, 'computeTarget', 'compute_target')
_deprecate(aml_config, v2, 'maxTrialNumPerGpu')
_move_field(aml_config, ts, 'maxTrialNumPerGpu', 'max_trial_number_per_gpu')
_deprecate(aml_config, v2, 'useActiveGpu')
assert not aml_config, aml_config
......@@ -259,3 +236,31 @@ def _deprecate(v1, v2, key):
if v2._deprecated is None:
v2._deprecated = {}
v2._deprecated[key] = v1.pop(key)
def convert_algo(algo_type, v1, v2):
if algo_type not in v1:
return None
v1_algo = v1.pop(algo_type)
builtin_name = v1_algo.pop(f'builtin{algo_type.title()}Name', None)
class_args = v1_algo.pop('classArgs', None)
if builtin_name is not None:
v2_algo = AlgorithmConfig(name=builtin_name, class_args=class_args)
else:
class_directory = util.canonical_path(v1_algo.pop('codeDir'))
class_file_name = v1_algo.pop('classFileName')
assert class_file_name.endswith('.py')
class_name = class_file_name[:-3] + '.' + v1_algo.pop('className')
v2_algo = CustomAlgorithmConfig(
class_name=class_name,
class_directory=class_directory,
class_args=class_args
)
setattr(v2, algo_type, v2_algo)
_deprecate(v1_algo, v2, 'includeIntermediateResults')
_move_field(v1_algo, v2, 'gpuIndices', 'tuner_gpu_indices')
assert not v1_algo, v1_algo
return v2_algo
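A hedged example of the custom-algorithm branch of convert_algo; the paths and class names below are made up:

v1 = {'tuner': {'codeDir': './my_tuner_dir',
                'classFileName': 'my_tuner.py',
                'className': 'MyTuner',
                'classArgs': {'optimize_mode': 'minimize'}}}
# convert_algo('tuner', v1, v2) pops the section and sets v2.tuner to roughly:
#   CustomAlgorithmConfig(class_name='my_tuner.MyTuner',
#                         class_directory=<canonicalized ./my_tuner_dir>,
#                         class_args={'optimize_mode': 'minimize'})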
......@@ -19,7 +19,6 @@ from . import management
from . import rest
from ..tools.nnictl.command_utils import kill_command
nni.runtime.log.init_logger_experiment()
_logger = logging.getLogger('nni.experiment')
......@@ -40,7 +39,7 @@ class Experiment:
"""
Prepare an experiment.
Use `Experiment.start()` to launch it.
Use `Experiment.run()` to launch it.
Parameters
----------
......@@ -60,7 +59,7 @@ class Experiment:
experiment.config.trial_command = 'python3 trial.py'
experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
...
experiment.start(8080)
experiment.run(8080)
Parameters
----------
......@@ -71,6 +70,8 @@ class Experiment:
...
def __init__(self, config=None, training_service=None):
nni.runtime.log.init_logger_experiment()
self.config: Optional[ExperimentConfig] = None
self.id: Optional[str] = None
self.port: Optional[int] = None
......@@ -149,27 +150,30 @@ class Experiment:
self._proc = None
_logger.info('Experiment stopped')
def run(self, port: int = 8080, debug: bool = False) -> bool:
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool:
"""
Run the experiment.
This function will block until experiment finish or error.
If wait_completion is True, this function will block until the experiment finishes or fails.
Return `True` when the experiment is done; return `False` when it fails.
If wait_completion is False, this function does not block and returns None immediately.
"""
self.start(port, debug)
try:
while True:
time.sleep(10)
status = self.get_status()
if status == 'DONE' or status == 'STOPPED':
return True
if status == 'ERROR':
return False
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
finally:
self.stop()
if wait_completion:
try:
while True:
time.sleep(10)
status = self.get_status()
if status == 'DONE' or status == 'STOPPED':
return True
if status == 'ERROR':
return False
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
finally:
self.stop()
@classmethod
def connect(cls, port: int):
......@@ -194,7 +198,7 @@ class Experiment:
return experiment
@classmethod
def resume(cls, experiment_id: str, port: int, wait_completion: bool = True, debug: bool = False):
def resume(cls, experiment_id: str, port: int = 8080, wait_completion: bool = True, debug: bool = False):
"""
Resume a stopped experiment.
......@@ -202,17 +206,22 @@ class Experiment:
----------
experiment_id
The stopped experiment id.
port
The port of the web UI.
wait_completion
If true, run in the foreground. If false, run in the background.
debug
Whether to start in debug mode.
"""
experiment = Experiment()
experiment.id = experiment_id
experiment.mode = 'resume'
if wait_completion:
experiment.run(port, debug)
else:
experiment.start(port, debug)
experiment.run(port=port, wait_completion=wait_completion, debug=debug)
if not wait_completion:
return experiment
@classmethod
def view(cls, experiment_id: str, port: int, wait_completion: bool = True, debug: bool = False):
def view(cls, experiment_id: str, port: int = 8080, non_blocking: bool = False):
"""
View a stopped experiment.
......@@ -220,14 +229,26 @@ class Experiment:
----------
experiment_id
The stopped experiment id.
port
The port of the web UI.
non_blocking
If false, run in the foreground. If true, run in the background.
"""
debug = False
experiment = Experiment()
experiment.id = experiment_id
experiment.mode = 'view'
if wait_completion:
experiment.run(port, debug)
else:
experiment.start(port, debug)
experiment.start(port=port, debug=debug)
if non_blocking:
return experiment
else:
try:
while True:
time.sleep(10)
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
finally:
experiment.stop()
def get_status(self) -> str:
"""
......
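Putting the new wait_completion flag together, a hedged usage sketch of launching in the background and polling (config construction elided; import path and constructor usage assumed from this diff):

from nni.experiment import Experiment

# `config` built as in the class docstring above (trial_command, machines, ...)
exp = Experiment(config)                     # per this diff: __init__(config=None, training_service=None)
exp.run(port=8080, wait_completion=False)    # returns None immediately; experiment keeps running
print(exp.get_status())                      # e.g. 'RUNNING'
# ... do other work, then shut it down explicitly:
exp.stop()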
......@@ -43,10 +43,9 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
_check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, config.experiment_working_directory)
if mode != 'view':
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
config.experiment_name, proc.pid, str(config.experiment_working_directory))
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
return proc
except Exception as e:
......@@ -116,7 +115,8 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
'mode': ts,
'experiment_id': experiment_id,
'start_mode': mode,
'log_level': 'debug' if debug else 'info',
'log_dir': config.experiment_working_directory,
'log_level': 'debug' if debug else 'info'
}
if pipe_path is not None:
args['dispatcher_pipe'] = pipe_path
......@@ -167,10 +167,10 @@ def get_stopped_experiment_config(exp_id: str, mode: str) -> None:
experiments_dict = experiments_config.get_all_experiments()
experiment_metadata = experiments_dict.get(exp_id)
if experiment_metadata is None:
logging.error('Id %s not exist!', exp_id)
_logger.error('Id %s not exist!', exp_id)
return
if experiment_metadata['status'] != 'STOPPED':
logging.error('Only stopped experiments can be %sed!', mode)
_logger.error('Only stopped experiments can be %sed!', mode)
return
experiment_config = Config(exp_id, experiment_metadata['logDir']).get_config()
config = ExperimentConfig(**experiment_config)
......
......@@ -19,7 +19,7 @@ def request(method: str, port: Optional[int], api: str, data: Any = None) -> Any
if not resp.ok:
_logger.error('rest request %s %s failed: %s %s', method.upper(), url, resp.status_code, resp.text)
resp.raise_for_status()
if method.lower() in ['get', 'post']:
if method.lower() in ['get', 'post'] and len(resp.content) > 0:
return resp.json()
def get(port: Optional[int], api: str) -> Any:
......
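The len(resp.content) guard exists because requests' Response.json() raises on an empty body; a minimal, runnable illustration of the failure mode being avoided (standalone snippet, not repo code):

# An empty 200 reply has no JSON body, and Response.json() does roughly
# the equivalent of json.loads(resp.content):
import json
try:
    json.loads('')                           # what parsing an empty body amounts to
except json.JSONDecodeError as err:
    print('resp.json() would raise here:', err)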
......@@ -72,7 +72,7 @@ class Mutable(nn.Module):
"""
After the search space is parsed, it will be the module name of the mutable.
"""
return self._name if hasattr(self, "_name") else "_key"
return self._name if hasattr(self, "_name") else self._key
@name.setter
def name(self, name):
......
......@@ -84,7 +84,7 @@ class Mutator(BaseMutator):
data = dict()
for k, v in self._cache.items():
if torch.is_tensor(v):
v = v.detach().cpu().numpy()
v = v.detach().cpu().numpy().tolist()
if isinstance(v, np.ndarray):
v = v.astype(np.float32).tolist()
data[k] = v
......
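The switch from .numpy() to .numpy().tolist() is about JSON serialization; a small runnable illustration of why the raw ndarray cannot be dumped directly (illustrative, not repo code):

import json
import numpy as np

choice = np.array([1.0, 0.0], dtype=np.float32)
try:
    json.dumps({'op_choice': choice})        # TypeError: ndarray is not JSON serializable
except TypeError as err:
    print('raw ndarray fails:', err)
print(json.dumps({'op_choice': choice.astype(np.float32).tolist()}))  # works: plain floats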
......@@ -13,7 +13,7 @@ _default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener',
'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec']
'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']
def set_execution_engine(engine) -> None:
global _execution_engine
......@@ -22,6 +22,7 @@ def set_execution_engine(engine) -> None:
else:
raise RuntimeError('execution engine is already set')
def get_execution_engine() -> AbstractExecutionEngine:
"""
Currently we assume the default execution engine is BaseExecutionEngine.
......@@ -67,3 +68,8 @@ def query_available_resources() -> int:
def is_stopped_exec(model: Model) -> bool:
return model.status in (ModelStatus.Trained, ModelStatus.Failed)
def budget_exhausted() -> bool:
engine = get_execution_engine()
return engine.budget_exhausted()
......@@ -104,6 +104,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def query_available_resource(self) -> int:
return self.resources
def budget_exhausted(self) -> bool:
advisor = get_advisor()
return advisor.stopping
@classmethod
def trial_execute_graph(cls) -> None:
"""
......
......@@ -130,6 +130,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
def query_available_resource(self) -> List[WorkerInfo]:
raise NotImplementedError # move the method from listener to here?
def budget_exhausted(self) -> bool:
raise NotImplementedError
@classmethod
def trial_execute_graph(cls) -> None:
"""
......
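A hedged sketch of how a search strategy might consume the new budget_exhausted() helper; the module path and loop shape are assumptions, and only submit_models, query_available_resources, and budget_exhausted come from this diff:

from nni.retiarii.execution import submit_models, query_available_resources, budget_exhausted

def explore(sample_next_model):              # sample_next_model: hypothetical sampler callable
    while not budget_exhausted():            # advisor signals max trials/duration reached
        if query_available_resources() > 0:
            submit_models(sample_next_model())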