Unverified Commit 77dac12b authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

Merge pull request #3023 from microsoft/v1.9

[do not squash!] merge v1.9 back to master
parents c2e69672 98a72a1e
docs/img/webui-img/search-trial.png

21.6 KB | W: | H:

docs/img/webui-img/search-trial.png

18 KB | W: | H:

docs/img/webui-img/search-trial.png
docs/img/webui-img/search-trial.png
docs/img/webui-img/search-trial.png
docs/img/webui-img/search-trial.png
  • 2-up
  • Swipe
  • Onion skin
docs/img/webui-img/select-trial.png

22.3 KB | W: | H:

docs/img/webui-img/select-trial.png

22.7 KB | W: | H:

docs/img/webui-img/select-trial.png
docs/img/webui-img/select-trial.png
docs/img/webui-img/select-trial.png
docs/img/webui-img/select-trial.png
  • 2-up
  • Swipe
  • Onion skin
# AMCPruner Example
This example shows how to use AMCPruner to compress a model.
## Step 1: train a model for pruning
Run the following command to train a mobilenetv2 model:
```bash
python3 amc_train.py --model_type mobilenetv2 --n_epoch 50
```
Once finished, the saved checkpoint file can be found at:
```
logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
```
## Step 2: prune the model with AMCPruner
Run the following command to prune the trained model:
```bash
python3 amc_search.py --model_type mobilenetv2 --ckpt logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
```
Once finished, the pruned model and mask can be found at:
```
logs/mobilenetv2_cifar10_r0.5_search-run2
```
## Step 3: finetune the pruned model
Run `amc_train.py` again with `--ckpt` and `--mask` to speed up and finetune the pruned model:
```bash
python3 amc_train.py --model_type mobilenetv2 --ckpt logs/mobilenetv2_cifar10_r0.5_search-run2/best_model.pth --mask logs/mobilenetv2_cifar10_r0.5_search-run2/best_mask.pth --n_epoch 100
```
...@@ -20,7 +20,7 @@ def parse_args(): ...@@ -20,7 +20,7 @@ def parse_args():
help='model to prune') help='model to prune')
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)') parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size') parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size')
parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path') parser.add_argument('--data_root', default='./data', type=str, help='dataset path')
parser.add_argument('--flops_ratio', default=0.5, type=float, help='target flops ratio to preserve of the model') parser.add_argument('--flops_ratio', default=0.5, type=float, help='target flops ratio to preserve of the model')
parser.add_argument('--lbound', default=0.2, type=float, help='minimum sparsity') parser.add_argument('--lbound', default=0.2, type=float, help='minimum sparsity')
parser.add_argument('--rbound', default=1., type=float, help='maximum sparsity') parser.add_argument('--rbound', default=1., type=float, help='maximum sparsity')
......
...@@ -13,6 +13,7 @@ import torch ...@@ -13,6 +13,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from torchvision.models import resnet
from nni.compression.torch.pruning.amc.lib.net_measure import measure_model from nni.compression.torch.pruning.amc.lib.net_measure import measure_model
from nni.compression.torch.pruning.amc.lib.utils import get_output_folder from nni.compression.torch.pruning.amc.lib.utils import get_output_folder
...@@ -27,7 +28,9 @@ from mobilenet_v2 import MobileNetV2 ...@@ -27,7 +28,9 @@ from mobilenet_v2 import MobileNetV2
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='AMC train / fine-tune script') parser = argparse.ArgumentParser(description='AMC train / fine-tune script')
parser.add_argument('--model_type', default='mobilenet', type=str, help='name of the model to train') parser.add_argument('--model_type', default='mobilenet', type=str,
choices=['mobilenet', 'mobilenetv2', 'resnet18', 'resnet34', 'resnet50'],
help='name of the model to train')
parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train') parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train')
parser.add_argument('--lr', default=0.05, type=float, help='learning rate') parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use') parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use')
...@@ -62,17 +65,21 @@ def get_model(args): ...@@ -62,17 +65,21 @@ def get_model(args):
net = MobileNet(n_class=n_class) net = MobileNet(n_class=n_class)
elif args.model_type == 'mobilenetv2': elif args.model_type == 'mobilenetv2':
net = MobileNetV2(n_class=n_class) net = MobileNetV2(n_class=n_class)
elif args.model_type.startswith('resnet'):
net = resnet.__dict__[args.model_type](pretrained=True)
in_features = net.fc.in_features
net.fc = nn.Linear(in_features, n_class)
else: else:
raise NotImplementedError raise NotImplementedError
if args.ckpt_path is not None: if args.ckpt_path is not None:
# the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py # the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py
print('=> Loading checkpoint {} ..'.format(args.ckpt_path)) print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
net.load_state_dict(torch.load(args.ckpt_path)) net.load_state_dict(torch.load(args.ckpt_path, torch.device('cpu')))
if args.mask_path is not None: if args.mask_path is not None:
SZ = 224 if args.dataset == 'imagenet' else 32 SZ = 224 if args.dataset == 'imagenet' else 32
data = torch.randn(2, 3, SZ, SZ) data = torch.randn(2, 3, SZ, SZ)
ms = ModelSpeedup(net, data, args.mask_path) ms = ModelSpeedup(net, data, args.mask_path, torch.device('cpu'))
ms.speedup_model() ms.speedup_model()
net.to(args.device) net.to(args.device)
...@@ -179,11 +186,11 @@ def adjust_learning_rate(optimizer, epoch): ...@@ -179,11 +186,11 @@ def adjust_learning_rate(optimizer, epoch):
return lr return lr
def save_checkpoint(state, is_best, checkpoint_dir='.'): def save_checkpoint(state, is_best, checkpoint_dir='.'):
filename = os.path.join(checkpoint_dir, 'ckpt.pth.tar') filename = os.path.join(checkpoint_dir, 'ckpt.pth')
print('=> Saving checkpoint to {}'.format(filename)) print('=> Saving checkpoint to {}'.format(filename))
torch.save(state, filename) torch.save(state, filename)
if is_best: if is_best:
shutil.copyfile(filename, filename.replace('.pth.tar', '.best.pth.tar')) shutil.copyfile(filename, filename.replace('.pth', '.best.pth'))
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import Model from tensorflow.keras import Model
from tensorflow.keras.layers import (AveragePooling2D, BatchNormalization, Conv2D, Dense, MaxPool2D) from tensorflow.keras.layers import (AveragePooling2D, BatchNormalization, Conv2D, Dense, MaxPool2D)
...@@ -7,8 +10,6 @@ from tensorflow.keras.optimizers import SGD ...@@ -7,8 +10,6 @@ from tensorflow.keras.optimizers import SGD
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
from nni.nas.tensorflow.enas import EnasTrainer from nni.nas.tensorflow.enas import EnasTrainer
tf.get_logger().setLevel('ERROR')
class Net(Model): class Net(Model):
def __init__(self): def __init__(self):
...@@ -53,35 +54,36 @@ class Net(Model): ...@@ -53,35 +54,36 @@ class Net(Model):
return x return x
def accuracy(output, target): def accuracy(truth, logits):
bs = target.shape[0] truth = tf.reshape(truth, -1)
predicted = tf.cast(tf.argmax(output, 1), target.dtype) predicted = tf.cast(tf.math.argmax(logits, axis=1), truth.dtype)
target = tf.reshape(target, [-1]) equal = tf.cast(predicted == truth, tf.int32)
return sum(tf.cast(predicted == target, tf.float32)) / bs return tf.math.reduce_sum(equal).numpy() / equal.shape[0]
def accuracy_metrics(truth, logits):
acc = accuracy(truth, logits)
return {'accuracy': acc}
if __name__ == '__main__': if __name__ == '__main__':
cifar10 = tf.keras.datasets.cifar10 cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data() (x_train, y_train), (x_valid, y_valid) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 x_train, x_valid = x_train / 255.0, x_valid / 255.0
split = int(len(x_train) * 0.9) train_set = (x_train, y_train)
dataset_train = tf.data.Dataset.from_tensor_slices((x_train[:split], y_train[:split])).batch(64) valid_set = (x_valid, y_valid)
dataset_valid = tf.data.Dataset.from_tensor_slices((x_train[split:], y_train[split:])).batch(64)
dataset_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(64)
net = Net() net = Net()
trainer = EnasTrainer( trainer = EnasTrainer(
net, net,
loss=SparseCategoricalCrossentropy(reduction=Reduction.SUM), loss=SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE),
metrics=accuracy, metrics=accuracy_metrics,
reward_function=accuracy, reward_function=accuracy,
optimizer=SGD(learning_rate=0.001, momentum=0.9), optimizer=SGD(learning_rate=0.001, momentum=0.9),
batch_size=64, batch_size=64,
num_epochs=2, num_epochs=2,
dataset_train=dataset_train, dataset_train=train_set,
dataset_valid=dataset_valid, dataset_valid=valid_set
dataset_test=dataset_test
) )
trainer.train() trainer.train()
#trainer.export('checkpoint')
...@@ -45,6 +45,7 @@ if __name__ == "__main__": ...@@ -45,6 +45,7 @@ if __name__ == "__main__":
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
model = ShuffleNetV2OneShot() model = ShuffleNetV2OneShot()
flops_func = model.get_candidate_flops
if args.load_checkpoint: if args.load_checkpoint:
if not args.spos_preprocessing: if not args.spos_preprocessing:
logger.warning("You might want to use SPOS preprocessing if you are loading their checkpoints.") logger.warning("You might want to use SPOS preprocessing if you are loading their checkpoints.")
...@@ -52,7 +53,7 @@ if __name__ == "__main__": ...@@ -52,7 +53,7 @@ if __name__ == "__main__":
model.cuda() model.cuda()
if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu
model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1))) model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1)))
mutator = SPOSSupernetTrainingMutator(model, flops_func=model.module.get_candidate_flops, mutator = SPOSSupernetTrainingMutator(model, flops_func=flops_func,
flops_lb=290E6, flops_ub=360E6) flops_lb=290E6, flops_ub=360E6)
criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing) criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
......
...@@ -17,9 +17,9 @@ tuner: ...@@ -17,9 +17,9 @@ tuner:
trial: trial:
command: python3 mnist.py command: python3 mnist.py
codeDir: . codeDir: .
computeTarget: ${replace_to_your_computeTarget}
image: msranni/nni image: msranni/nni
amlConfig: amlConfig:
subscriptionId: ${replace_to_your_subscriptionId} subscriptionId: ${replace_to_your_subscriptionId}
resourceGroup: ${replace_to_your_resourceGroup} resourceGroup: ${replace_to_your_resourceGroup}
workspaceName: ${replace_to_your_workspaceName} workspaceName: ${replace_to_your_workspaceName}
computeTarget: ${replace_to_your_computeTarget}
...@@ -17,9 +17,9 @@ tuner: ...@@ -17,9 +17,9 @@ tuner:
trial: trial:
command: python3 mnist.py command: python3 mnist.py
codeDir: . codeDir: .
computeTarget: ${replace_to_your_computeTarget}
image: msranni/nni image: msranni/nni
amlConfig: amlConfig:
subscriptionId: ${replace_to_your_subscriptionId} subscriptionId: ${replace_to_your_subscriptionId}
resourceGroup: ${replace_to_your_resourceGroup} resourceGroup: ${replace_to_your_resourceGroup}
workspaceName: ${replace_to_your_workspaceName} workspaceName: ${replace_to_your_workspaceName}
computeTarget: ${replace_to_your_computeTarget}
...@@ -136,6 +136,10 @@ class LinuxCommands extends OsCommands { ...@@ -136,6 +136,10 @@ class LinuxCommands extends OsCommands {
return `${preCommand} && ${command}`; return `${preCommand} && ${command}`;
} }
} }
public fileExistCommand(filePath: string): string {
return `test -e ${filePath} && echo True || echo False`;
}
} }
export { LinuxCommands }; export { LinuxCommands };
...@@ -130,6 +130,10 @@ class WindowsCommands extends OsCommands { ...@@ -130,6 +130,10 @@ class WindowsCommands extends OsCommands {
return `${preCommand} && set prePath=%path% && ${command}`; return `${preCommand} && set prePath=%path% && ${command}`;
} }
} }
public fileExistCommand(filePath: string): string {
return `powershell Test-Path ${filePath} -PathType Leaf`;
}
} }
export { WindowsCommands }; export { WindowsCommands };
...@@ -29,6 +29,7 @@ abstract class OsCommands { ...@@ -29,6 +29,7 @@ abstract class OsCommands {
public abstract extractFile(tarFileName: string, targetFolder: string): string; public abstract extractFile(tarFileName: string, targetFolder: string): string;
public abstract executeScript(script: string, isFile: boolean): string; public abstract executeScript(script: string, isFile: boolean): string;
public abstract addPreCommand(preCommand: string | undefined, command: string | undefined): string | undefined; public abstract addPreCommand(preCommand: string | undefined, command: string | undefined): string | undefined;
public abstract fileExistCommand(filePath: string): string | undefined;
public joinPath(...paths: string[]): string { public joinPath(...paths: string[]): string {
let dir: string = paths.filter((path: any) => path !== '').join(this.pathSpliter); let dir: string = paths.filter((path: any) => path !== '').join(this.pathSpliter);
......
...@@ -238,6 +238,12 @@ class ShellExecutor { ...@@ -238,6 +238,12 @@ class ShellExecutor {
return commandResult.exitCode == 0; return commandResult.exitCode == 0;
} }
public async fileExist(filePath: string): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.fileExistCommand(filePath);
const commandResult = await this.execute(commandText);
return commandResult.stdout !== undefined && commandResult.stdout.trim() === 'True';
}
public async extractFile(tarFileName: string, targetFolder: string): Promise<boolean> { public async extractFile(tarFileName: string, targetFolder: string): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.extractFile(tarFileName, targetFolder); const commandText = this.osCommands && this.osCommands.extractFile(tarFileName, targetFolder);
const commandResult = await this.execute(commandText); const commandResult = await this.execute(commandText);
......
...@@ -139,10 +139,15 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -139,10 +139,15 @@ export class RemoteEnvironmentService extends EnvironmentService {
const executor = await this.getExecutor(environment.id); const executor = await this.getExecutor(environment.id);
const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`; const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`; const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`;
if (fs.existsSync(jobpidPath)) {
/* eslint-disable require-atomic-updates */ /* eslint-disable require-atomic-updates */
try { try {
// check if pid file exist
const pidExist = await executor.fileExist(jobpidPath);
if (!pidExist) {
return;
}
const isAlive = await executor.isProcessAlive(jobpidPath); const isAlive = await executor.isProcessAlive(jobpidPath);
environment.status = 'RUNNING';
// if the process of jobpid is not alive any more // if the process of jobpid is not alive any more
if (!isAlive) { if (!isAlive) {
const remoteEnvironment: RemoteMachineEnvironmentInformation = environment as RemoteMachineEnvironmentInformation; const remoteEnvironment: RemoteMachineEnvironmentInformation = environment as RemoteMachineEnvironmentInformation;
...@@ -167,11 +172,9 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -167,11 +172,9 @@ export class RemoteEnvironmentService extends EnvironmentService {
} }
} }
} catch (error) { } catch (error) {
this.releaseEnvironmentResource(environment);
this.log.error(`Update job status exception, error is ${error.message}`); this.log.error(`Update job status exception, error is ${error.message}`);
} }
} }
}
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> { public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
const tasks: Promise<void>[] = []; const tasks: Promise<void>[] = [];
...@@ -245,6 +248,7 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -245,6 +248,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
'envs', environment.id) 'envs', environment.id)
environment.command = `cd ${environment.runnerWorkingFolder} && \ environment.command = `cd ${environment.runnerWorkingFolder} && \
${environment.command} --job_pid_file ${environment.runnerWorkingFolder}/pid \ ${environment.command} --job_pid_file ${environment.runnerWorkingFolder}/pid \
1>${environment.runnerWorkingFolder}/trialrunner_stdout 2>${environment.runnerWorkingFolder}/trialrunner_stderr \
&& echo $? \`date +%s%3N\` >${environment.runnerWorkingFolder}/code`; && echo $? \`date +%s%3N\` >${environment.runnerWorkingFolder}/code`;
return Promise.resolve(true); return Promise.resolve(true);
} }
...@@ -266,7 +270,6 @@ ${environment.command} --job_pid_file ${environment.runnerWorkingFolder}/pid \ ...@@ -266,7 +270,6 @@ ${environment.command} --job_pid_file ${environment.runnerWorkingFolder}/pid \
// Execute command in remote machine // Execute command in remote machine
executor.executeScript(executor.joinPath(environment.runnerWorkingFolder, executor.executeScript(executor.joinPath(environment.runnerWorkingFolder,
executor.getScriptName("run")), true, false); executor.getScriptName("run")), true, false);
environment.status = 'RUNNING';
if (environment.rmMachineMeta === undefined) { if (environment.rmMachineMeta === undefined) {
throw new Error(`${environment.id} rmMachineMeta not initialized!`); throw new Error(`${environment.id} rmMachineMeta not initialized!`);
} }
......
...@@ -663,18 +663,21 @@ class TrialDispatcher implements TrainingService { ...@@ -663,18 +663,21 @@ class TrialDispatcher implements TrainingService {
await this.commandChannel.sendCommand(trial.environment, NEW_TRIAL_JOB, trial.settings); await this.commandChannel.sendCommand(trial.environment, NEW_TRIAL_JOB, trial.settings);
} }
/**
* release the trial assigned environment resources
* @param trial
*/
private releaseEnvironment(trial: TrialDetail): void { private releaseEnvironment(trial: TrialDetail): void {
if (undefined === trial.environment) { if (trial.environment !== undefined) {
throw new Error(`TrialDispatcher: environment is not assigned to trial ${trial.id}, and cannot be released!`);
}
if (trial.environment.runningTrialCount <= 0) { if (trial.environment.runningTrialCount <= 0) {
throw new Error(`TrialDispatcher: environment ${trial.environment.id} has no counted running trial!`); throw new Error(`TrialDispatcher: environment ${trial.environment.id} has no counted running trial!`);
} }
trial.environment.runningTrialCount--;
trial.environment = undefined;
}
if (true === this.enableGpuScheduler) { if (true === this.enableGpuScheduler) {
this.gpuScheduler.removeGpuReservation(trial); this.gpuScheduler.removeGpuReservation(trial);
} }
trial.environment.runningTrialCount--;
trial.environment = undefined;
} }
private async handleMetricData(trialId: string, data: any): Promise<void> { private async handleMetricData(trialId: string, data: any): Promise<void> {
......
...@@ -1089,10 +1089,10 @@ cli-width@^2.0.0: ...@@ -1089,10 +1089,10 @@ cli-width@^2.0.0:
version "2.2.0" version "2.2.0"
resolved "https://registry.yarnpkg.com/cli-width/-/cli-width-2.2.0.tgz#ff19ede8a9a5e579324147b0c11f0fbcbabed639" resolved "https://registry.yarnpkg.com/cli-width/-/cli-width-2.2.0.tgz#ff19ede8a9a5e579324147b0c11f0fbcbabed639"
cliui@^7.0.0: cliui@^7.0.2:
version "7.0.1" version "7.0.3"
resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.1.tgz#a4cb67aad45cd83d8d05128fc9f4d8fbb887e6b3" resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.3.tgz#ef180f26c8d9bff3927ee52428bfec2090427981"
integrity sha512-rcvHOWyGyid6I1WjT/3NatKj2kDt9OdSHSXpyLXaMWFbKpGACNW8pRhhdPUq9MWUOdwn8Rz9AVETjF4105rZZQ== integrity sha512-Gj3QHTkVMPKqwP3f7B4KPkBZRMR9r4rfi5bXFpg1a+Svvj8l7q5CnkBkVQzfxT5DFSsGk2+PascOgL0JYkL2kw==
dependencies: dependencies:
string-width "^4.2.0" string-width "^4.2.0"
strip-ansi "^6.0.0" strip-ansi "^6.0.0"
...@@ -1329,7 +1329,7 @@ debug@^3.1.0: ...@@ -1329,7 +1329,7 @@ debug@^3.1.0:
dependencies: dependencies:
ms "^2.1.1" ms "^2.1.1"
debuglog@^1.0.1: debuglog@*, debuglog@^1.0.1:
version "1.0.1" version "1.0.1"
resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492" resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
...@@ -1599,10 +1599,10 @@ es6-promisify@^5.0.0: ...@@ -1599,10 +1599,10 @@ es6-promisify@^5.0.0:
dependencies: dependencies:
es6-promise "^4.0.3" es6-promise "^4.0.3"
escalade@^3.0.2: escalade@^3.1.1:
version "3.1.0" version "3.1.1"
resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.0.tgz#e8e2d7c7a8b76f6ee64c2181d6b8151441602d4e" resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40"
integrity sha512-mAk+hPSO8fLDkhV7V0dXazH5pDc6MrjBTPyD3VeKzxnVFjH1MIxbCdqGZB9O8+EwWakZs3ZCbDS4IpRt79V1ig== integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==
escape-html@~1.0.3: escape-html@~1.0.3:
version "1.0.3" version "1.0.3"
...@@ -2375,7 +2375,7 @@ import-lazy@^2.1.0: ...@@ -2375,7 +2375,7 @@ import-lazy@^2.1.0:
version "2.1.0" version "2.1.0"
resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43" resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
imurmurhash@^0.1.4: imurmurhash@*, imurmurhash@^0.1.4:
version "0.1.4" version "0.1.4"
resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea" resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
integrity sha1-khi5srkoojixPcT7a21XbyMUU+o= integrity sha1-khi5srkoojixPcT7a21XbyMUU+o=
...@@ -3057,6 +3057,11 @@ lockfile@^1.0.4: ...@@ -3057,6 +3057,11 @@ lockfile@^1.0.4:
dependencies: dependencies:
signal-exit "^3.0.2" signal-exit "^3.0.2"
lodash._baseindexof@*:
version "3.1.0"
resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
integrity sha1-/lK1OhxnYeQmGNZU5KJXie1hgiw=
lodash._baseuniq@~4.6.0: lodash._baseuniq@~4.6.0:
version "4.6.0" version "4.6.0"
resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8" resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
...@@ -3064,10 +3069,32 @@ lodash._baseuniq@~4.6.0: ...@@ -3064,10 +3069,32 @@ lodash._baseuniq@~4.6.0:
lodash._createset "~4.0.0" lodash._createset "~4.0.0"
lodash._root "~3.0.0" lodash._root "~3.0.0"
lodash._bindcallback@*:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
integrity sha1-5THCdkTPi1epnhftlbNcdIeJOS4=
lodash._cacheindexof@*:
version "3.0.2"
resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
integrity sha1-PcaayCSY0u5ePOVgkbr9Ktx73pI=
lodash._createcache@*:
version "3.1.2"
resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
integrity sha1-VtagZAF2JeeevKa4AY4XRAvc8JM=
dependencies:
lodash._getnative "^3.0.0"
lodash._createset@~4.0.0: lodash._createset@~4.0.0:
version "4.0.3" version "4.0.3"
resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26" resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
lodash._getnative@*, lodash._getnative@^3.0.0:
version "3.9.1"
resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
integrity sha1-VwvH3t5G1hzc3mh9ZdPuy6o6r/U=
lodash._root@~3.0.0: lodash._root@~3.0.0:
version "3.0.1" version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692" resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
...@@ -3116,6 +3143,11 @@ lodash.pick@^4.4.0: ...@@ -3116,6 +3143,11 @@ lodash.pick@^4.4.0:
version "4.4.0" version "4.4.0"
resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3" resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
lodash.restparam@*:
version "3.6.1"
resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
integrity sha1-k2pOMJ7zMKdkXtQUWYbIWuWyCAU=
lodash.unescape@4.0.1: lodash.unescape@4.0.1:
version "4.0.1" version "4.0.1"
resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c" resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
...@@ -3679,10 +3711,11 @@ npm-run-path@^2.0.0: ...@@ -3679,10 +3711,11 @@ npm-run-path@^2.0.0:
path-key "^2.0.0" path-key "^2.0.0"
npm-user-validate@~1.0.0: npm-user-validate@~1.0.0:
version "1.0.0" version "1.0.1"
resolved "https://registry.yarnpkg.com/npm-user-validate/-/npm-user-validate-1.0.0.tgz#8ceca0f5cea04d4e93519ef72d0557a75122e951" resolved "https://registry.yarnpkg.com/npm-user-validate/-/npm-user-validate-1.0.1.tgz#31428fc5475fe8416023f178c0ab47935ad8c561"
integrity sha512-uQwcd/tY+h1jnEaze6cdX/LrhWhoBxfSknxentoqmIuStxUExxjWd3ULMLFPiFUrZKbOVMowH6Jq2FRWfmhcEw==
npm@5.1.0, npm@^6.14.8: npm@5.1.0, npm@>=6.14.8:
version "6.14.8" version "6.14.8"
resolved "https://registry.yarnpkg.com/npm/-/npm-6.14.8.tgz#64ef754345639bc035982ec3f609353c8539033c" resolved "https://registry.yarnpkg.com/npm/-/npm-6.14.8.tgz#64ef754345639bc035982ec3f609353c8539033c"
integrity sha512-HBZVBMYs5blsj94GTeQZel7s9odVuuSUHy1+AlZh7rPVux1os2ashvEGLy/STNK7vUjbrCg5Kq9/GXisJgdf6A== integrity sha512-HBZVBMYs5blsj94GTeQZel7s9odVuuSUHy1+AlZh7rPVux1os2ashvEGLy/STNK7vUjbrCg5Kq9/GXisJgdf6A==
...@@ -3705,6 +3738,7 @@ npm@5.1.0, npm@^6.14.8: ...@@ -3705,6 +3738,7 @@ npm@5.1.0, npm@^6.14.8:
cmd-shim "^3.0.3" cmd-shim "^3.0.3"
columnify "~1.5.4" columnify "~1.5.4"
config-chain "^1.1.12" config-chain "^1.1.12"
debuglog "*"
detect-indent "~5.0.0" detect-indent "~5.0.0"
detect-newline "^2.1.0" detect-newline "^2.1.0"
dezalgo "~1.0.3" dezalgo "~1.0.3"
...@@ -3719,6 +3753,7 @@ npm@5.1.0, npm@^6.14.8: ...@@ -3719,6 +3753,7 @@ npm@5.1.0, npm@^6.14.8:
has-unicode "~2.0.1" has-unicode "~2.0.1"
hosted-git-info "^2.8.8" hosted-git-info "^2.8.8"
iferr "^1.0.2" iferr "^1.0.2"
imurmurhash "*"
infer-owner "^1.0.4" infer-owner "^1.0.4"
inflight "~1.0.6" inflight "~1.0.6"
inherits "^2.0.4" inherits "^2.0.4"
...@@ -3737,8 +3772,14 @@ npm@5.1.0, npm@^6.14.8: ...@@ -3737,8 +3772,14 @@ npm@5.1.0, npm@^6.14.8:
libnpx "^10.2.4" libnpx "^10.2.4"
lock-verify "^2.1.0" lock-verify "^2.1.0"
lockfile "^1.0.4" lockfile "^1.0.4"
lodash._baseindexof "*"
lodash._baseuniq "~4.6.0" lodash._baseuniq "~4.6.0"
lodash._bindcallback "*"
lodash._cacheindexof "*"
lodash._createcache "*"
lodash._getnative "*"
lodash.clonedeep "~4.5.0" lodash.clonedeep "~4.5.0"
lodash.restparam "*"
lodash.union "~4.6.0" lodash.union "~4.6.0"
lodash.uniq "~4.5.0" lodash.uniq "~4.5.0"
lodash.without "~4.4.0" lodash.without "~4.4.0"
...@@ -5664,10 +5705,10 @@ y18n@^4.0.0: ...@@ -5664,10 +5705,10 @@ y18n@^4.0.0:
resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.0.tgz#95ef94f85ecc81d007c264e190a120f0a3c8566b" resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.0.tgz#95ef94f85ecc81d007c264e190a120f0a3c8566b"
integrity sha512-r9S/ZyXu/Xu9q1tYlpsLIsa3EeLXXk0VwlxqTcFRfg9EhMW+17kbt9G0NrgCmhGb5vT2hyhJZLfDGx+7+5Uj/w== integrity sha512-r9S/ZyXu/Xu9q1tYlpsLIsa3EeLXXk0VwlxqTcFRfg9EhMW+17kbt9G0NrgCmhGb5vT2hyhJZLfDGx+7+5Uj/w==
y18n@^5.0.1: y18n@^5.0.2:
version "5.0.1" version "5.0.4"
resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.1.tgz#1ad2a7eddfa8bce7caa2e1f6b5da96c39d99d571" resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.4.tgz#0ab2db89dd5873b5ec4682d8e703e833373ea897"
integrity sha512-/jJ831jEs4vGDbYPQp4yGKDYPSCCEQ45uZWJHE1AoYBzqdZi8+LDWas0z4HrmJXmKdpFsTiowSHXdxyFhpmdMg== integrity sha512-deLOfD+RvFgrpAmSZgfGdWYE+OKyHcVHaRQ7NphG/63scpRvTHHeQMAxGGvaLVGJ+HYVcCXlzcTK0ZehFf+eHQ==
yallist@^2.1.2: yallist@^2.1.2:
version "2.1.2" version "2.1.2"
...@@ -5686,10 +5727,10 @@ yallist@^4.0.0: ...@@ -5686,10 +5727,10 @@ yallist@^4.0.0:
version "4.0.0" version "4.0.0"
resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72"
yargs-parser@13.1.2, yargs-parser@^20.0.0, yargs-parser@^20.2.0: yargs-parser@13.1.2, yargs-parser@>=20.2.0, yargs-parser@^20.2.2:
version "20.2.0" version "20.2.3"
resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.0.tgz#944791ca2be2e08ddadd3d87e9de4c6484338605" resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.3.tgz#92419ba867b858c868acf8bae9bf74af0dd0ce26"
integrity sha512-2agPoRFPoIcFzOIp6656gcvsg2ohtscpw2OINr/q46+Sq41xz2OYLqx5HRHabmFU1OARIPAYH5uteICE7mn/5A== integrity sha512-emOFRT9WVHw03QSvN5qor9QQT9+sw5vwxfYweivSMHTcAXPefwVae2FjO7JJjj8hCE4CzPOPeFM83VwT29HCww==
yargs-unparser@1.6.1: yargs-unparser@1.6.1:
version "1.6.1" version "1.6.1"
...@@ -5702,18 +5743,18 @@ yargs-unparser@1.6.1: ...@@ -5702,18 +5743,18 @@ yargs-unparser@1.6.1:
is-plain-obj "^1.1.0" is-plain-obj "^1.1.0"
yargs "^14.2.3" yargs "^14.2.3"
yargs@13.3.2, yargs@^11.0.0, yargs@^14.2.3, yargs@^15.0.2, yargs@^16.0.3, yargs@^8.0.2: yargs@13.3.2, yargs@>=16.0.3, yargs@^11.0.0, yargs@^14.2.3, yargs@^15.0.2, yargs@^8.0.2:
version "16.0.3" version "16.1.0"
resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.0.3.tgz#7a919b9e43c90f80d4a142a89795e85399a7e54c" resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.1.0.tgz#fc333fe4791660eace5a894b39d42f851cd48f2a"
integrity sha512-6+nLw8xa9uK1BOEOykaiYAJVh6/CjxWXK/q9b5FpRgNslt8s22F2xMBqVIKgCRjNgGvGPBy8Vog7WN7yh4amtA== integrity sha512-upWFJOmDdHN0syLuESuvXDmrRcWd1QafJolHskzaw79uZa7/x53gxQKiR07W59GWY1tFhhU/Th9DrtSfpS782g==
dependencies: dependencies:
cliui "^7.0.0" cliui "^7.0.2"
escalade "^3.0.2" escalade "^3.1.1"
get-caller-file "^2.0.5" get-caller-file "^2.0.5"
require-directory "^2.1.1" require-directory "^2.1.1"
string-width "^4.2.0" string-width "^4.2.0"
y18n "^5.0.1" y18n "^5.0.2"
yargs-parser "^20.0.0" yargs-parser "^20.2.2"
yn@^2.0.0: yn@^2.0.0:
version "2.0.0" version "2.0.0"
......
...@@ -17,6 +17,11 @@ log_level_map = { ...@@ -17,6 +17,11 @@ log_level_map = {
_time_format = '%m/%d/%Y, %I:%M:%S %p' _time_format = '%m/%d/%Y, %I:%M:%S %p'
# FIXME
# This hotfix the bug that querying installed tuners with `package_utils` will activate dispatcher logger.
# This behavior depends on underlying implementation of `nnictl` and is likely to break in future.
_logger_initialized = False
class _LoggerFileWrapper(TextIOBase): class _LoggerFileWrapper(TextIOBase):
def __init__(self, logger_file): def __init__(self, logger_file):
self.file = logger_file self.file = logger_file
...@@ -33,6 +38,11 @@ def init_logger(logger_file_path, log_level_name='info'): ...@@ -33,6 +38,11 @@ def init_logger(logger_file_path, log_level_name='info'):
This will redirect anything from logging.getLogger() as well as stdout to specified file. This will redirect anything from logging.getLogger() as well as stdout to specified file.
logger_file_path: path of logger file (path-like object). logger_file_path: path of logger file (path-like object).
""" """
global _logger_initialized
if _logger_initialized:
return
_logger_initialized = True
log_level = log_level_map.get(log_level_name, logging.INFO) log_level = log_level_map.get(log_level_name, logging.INFO)
logger_file = open(logger_file_path, 'w') logger_file = open(logger_file_path, 'w')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
...@@ -55,6 +65,11 @@ def init_standalone_logger(): ...@@ -55,6 +65,11 @@ def init_standalone_logger():
Initialize root logger for standalone mode. Initialize root logger for standalone mode.
This will set NNI's log level to INFO and print its log to stdout. This will set NNI's log level to INFO and print its log to stdout.
""" """
global _logger_initialized
if _logger_initialized:
return
_logger_initialized = True
fmt = '[%(asctime)s] %(levelname)s (%(name)s) %(message)s' fmt = '[%(asctime)s] %(levelname)s (%(name)s) %(message)s'
formatter = logging.Formatter(fmt, _time_format) formatter = logging.Formatter(fmt, _time_format)
handler = logging.StreamHandler(sys.stdout) handler = logging.StreamHandler(sys.stdout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment