Unverified Commit 77dac12b authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

Merge pull request #3023 from microsoft/v1.9

[do not squash!] merge v1.9 back to master
parents c2e69672 98a72a1e
docs/img/webui-img/search-trial.png

21.6 KB | W: | H:

docs/img/webui-img/search-trial.png

18 KB | W: | H:

docs/img/webui-img/search-trial.png
docs/img/webui-img/search-trial.png
docs/img/webui-img/search-trial.png
docs/img/webui-img/search-trial.png
  • 2-up
  • Swipe
  • Onion skin
docs/img/webui-img/select-trial.png

22.3 KB | W: | H:

docs/img/webui-img/select-trial.png

22.7 KB | W: | H:

docs/img/webui-img/select-trial.png
docs/img/webui-img/select-trial.png
docs/img/webui-img/select-trial.png
docs/img/webui-img/select-trial.png
  • 2-up
  • Swipe
  • Onion skin
# AMCPruner Example
This example shows us how to use AMCPruner example.
## Step 1: train a model for pruning
Run following command to train a mobilenetv2 model:
```bash
python3 amc_train.py --model_type mobilenetv2 --n_epoch 50
```
Once finished, saved checkpoint file can be found at:
```
logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
```
## Pruning with AMCPruner
Run following command to prune the trained model:
```bash
python3 amc_search.py --model_type mobilenetv2 --ckpt logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
```
Once finished, pruned model and mask can be found at:
```
logs/mobilenetv2_cifar10_r0.5_search-run2
```
## Finetune pruned model
Run `amc_train.py` again with `--ckpt` and `--mask` to speedup and finetune the pruned model:
```bash
python3 amc_train.py --model_type mobilenetv2 --ckpt logs/mobilenetv2_cifar10_r0.5_search-run2/best_model.pth --mask logs/mobilenetv2_cifar10_r0.5_search-run2/best_mask.pth --n_epoch 100
```
......@@ -20,7 +20,7 @@ def parse_args():
help='model to prune')
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size')
parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path')
parser.add_argument('--data_root', default='./data', type=str, help='dataset path')
parser.add_argument('--flops_ratio', default=0.5, type=float, help='target flops ratio to preserve of the model')
parser.add_argument('--lbound', default=0.2, type=float, help='minimum sparsity')
parser.add_argument('--rbound', default=1., type=float, help='maximum sparsity')
......
......@@ -13,6 +13,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torchvision.models import resnet
from nni.compression.torch.pruning.amc.lib.net_measure import measure_model
from nni.compression.torch.pruning.amc.lib.utils import get_output_folder
......@@ -27,7 +28,9 @@ from mobilenet_v2 import MobileNetV2
def parse_args():
parser = argparse.ArgumentParser(description='AMC train / fine-tune script')
parser.add_argument('--model_type', default='mobilenet', type=str, help='name of the model to train')
parser.add_argument('--model_type', default='mobilenet', type=str,
choices=['mobilenet', 'mobilenetv2', 'resnet18', 'resnet34', 'resnet50'],
help='name of the model to train')
parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train')
parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use')
......@@ -62,17 +65,21 @@ def get_model(args):
net = MobileNet(n_class=n_class)
elif args.model_type == 'mobilenetv2':
net = MobileNetV2(n_class=n_class)
elif args.model_type.startswith('resnet'):
net = resnet.__dict__[args.model_type](pretrained=True)
in_features = net.fc.in_features
net.fc = nn.Linear(in_features, n_class)
else:
raise NotImplementedError
if args.ckpt_path is not None:
# the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py
print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
net.load_state_dict(torch.load(args.ckpt_path))
net.load_state_dict(torch.load(args.ckpt_path, torch.device('cpu')))
if args.mask_path is not None:
SZ = 224 if args.dataset == 'imagenet' else 32
data = torch.randn(2, 3, SZ, SZ)
ms = ModelSpeedup(net, data, args.mask_path)
ms = ModelSpeedup(net, data, args.mask_path, torch.device('cpu'))
ms.speedup_model()
net.to(args.device)
......@@ -179,11 +186,11 @@ def adjust_learning_rate(optimizer, epoch):
return lr
def save_checkpoint(state, is_best, checkpoint_dir='.'):
filename = os.path.join(checkpoint_dir, 'ckpt.pth.tar')
filename = os.path.join(checkpoint_dir, 'ckpt.pth')
print('=> Saving checkpoint to {}'.format(filename))
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, filename.replace('.pth.tar', '.best.pth.tar'))
shutil.copyfile(filename, filename.replace('.pth', '.best.pth'))
if __name__ == '__main__':
args = parse_args()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (AveragePooling2D, BatchNormalization, Conv2D, Dense, MaxPool2D)
......@@ -7,8 +10,6 @@ from tensorflow.keras.optimizers import SGD
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
from nni.nas.tensorflow.enas import EnasTrainer
tf.get_logger().setLevel('ERROR')
class Net(Model):
def __init__(self):
......@@ -53,35 +54,36 @@ class Net(Model):
return x
def accuracy(output, target):
bs = target.shape[0]
predicted = tf.cast(tf.argmax(output, 1), target.dtype)
target = tf.reshape(target, [-1])
return sum(tf.cast(predicted == target, tf.float32)) / bs
def accuracy(truth, logits):
truth = tf.reshape(truth, -1)
predicted = tf.cast(tf.math.argmax(logits, axis=1), truth.dtype)
equal = tf.cast(predicted == truth, tf.int32)
return tf.math.reduce_sum(equal).numpy() / equal.shape[0]
def accuracy_metrics(truth, logits):
acc = accuracy(truth, logits)
return {'accuracy': acc}
if __name__ == '__main__':
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
split = int(len(x_train) * 0.9)
dataset_train = tf.data.Dataset.from_tensor_slices((x_train[:split], y_train[:split])).batch(64)
dataset_valid = tf.data.Dataset.from_tensor_slices((x_train[split:], y_train[split:])).batch(64)
dataset_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(64)
(x_train, y_train), (x_valid, y_valid) = cifar10.load_data()
x_train, x_valid = x_train / 255.0, x_valid / 255.0
train_set = (x_train, y_train)
valid_set = (x_valid, y_valid)
net = Net()
trainer = EnasTrainer(
net,
loss=SparseCategoricalCrossentropy(reduction=Reduction.SUM),
metrics=accuracy,
loss=SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE),
metrics=accuracy_metrics,
reward_function=accuracy,
optimizer=SGD(learning_rate=0.001, momentum=0.9),
batch_size=64,
num_epochs=2,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
dataset_test=dataset_test
dataset_train=train_set,
dataset_valid=valid_set
)
trainer.train()
#trainer.export('checkpoint')
......@@ -45,6 +45,7 @@ if __name__ == "__main__":
torch.backends.cudnn.deterministic = True
model = ShuffleNetV2OneShot()
flops_func = model.get_candidate_flops
if args.load_checkpoint:
if not args.spos_preprocessing:
logger.warning("You might want to use SPOS preprocessing if you are loading their checkpoints.")
......@@ -52,7 +53,7 @@ if __name__ == "__main__":
model.cuda()
if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu
model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1)))
mutator = SPOSSupernetTrainingMutator(model, flops_func=model.module.get_candidate_flops,
mutator = SPOSSupernetTrainingMutator(model, flops_func=flops_func,
flops_lb=290E6, flops_ub=360E6)
criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
......
......@@ -17,9 +17,9 @@ tuner:
trial:
command: python3 mnist.py
codeDir: .
computeTarget: ${replace_to_your_computeTarget}
image: msranni/nni
amlConfig:
subscriptionId: ${replace_to_your_subscriptionId}
resourceGroup: ${replace_to_your_resourceGroup}
workspaceName: ${replace_to_your_workspaceName}
computeTarget: ${replace_to_your_computeTarget}
......@@ -17,9 +17,9 @@ tuner:
trial:
command: python3 mnist.py
codeDir: .
computeTarget: ${replace_to_your_computeTarget}
image: msranni/nni
amlConfig:
subscriptionId: ${replace_to_your_subscriptionId}
resourceGroup: ${replace_to_your_resourceGroup}
workspaceName: ${replace_to_your_workspaceName}
computeTarget: ${replace_to_your_computeTarget}
......@@ -136,6 +136,10 @@ class LinuxCommands extends OsCommands {
return `${preCommand} && ${command}`;
}
}
public fileExistCommand(filePath: string): string {
return `test -e ${filePath} && echo True || echo False`;
}
}
export { LinuxCommands };
......@@ -130,6 +130,10 @@ class WindowsCommands extends OsCommands {
return `${preCommand} && set prePath=%path% && ${command}`;
}
}
public fileExistCommand(filePath: string): string {
return `powershell Test-Path ${filePath} -PathType Leaf`;
}
}
export { WindowsCommands };
......@@ -29,6 +29,7 @@ abstract class OsCommands {
public abstract extractFile(tarFileName: string, targetFolder: string): string;
public abstract executeScript(script: string, isFile: boolean): string;
public abstract addPreCommand(preCommand: string | undefined, command: string | undefined): string | undefined;
public abstract fileExistCommand(filePath: string): string | undefined;
public joinPath(...paths: string[]): string {
let dir: string = paths.filter((path: any) => path !== '').join(this.pathSpliter);
......
......@@ -238,6 +238,12 @@ class ShellExecutor {
return commandResult.exitCode == 0;
}
public async fileExist(filePath: string): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.fileExistCommand(filePath);
const commandResult = await this.execute(commandText);
return commandResult.stdout !== undefined && commandResult.stdout.trim() === 'True';
}
public async extractFile(tarFileName: string, targetFolder: string): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.extractFile(tarFileName, targetFolder);
const commandResult = await this.execute(commandText);
......
......@@ -137,40 +137,43 @@ export class RemoteEnvironmentService extends EnvironmentService {
private async refreshEnvironment(environment: EnvironmentInformation): Promise<void> {
const executor = await this.getExecutor(environment.id);
const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`;
if (fs.existsSync(jobpidPath)) {
/* eslint-disable require-atomic-updates */
try {
const isAlive = await executor.isProcessAlive(jobpidPath);
// if the process of jobpid is not alive any more
if (!isAlive) {
const remoteEnvironment: RemoteMachineEnvironmentInformation = environment as RemoteMachineEnvironmentInformation;
if (remoteEnvironment.rmMachineMeta === undefined) {
throw new Error(`${remoteEnvironment.id} machine meta not initialized!`);
}
this.log.info(`pid in ${remoteEnvironment.rmMachineMeta.ip}:${jobpidPath} is not alive!`);
if (fs.existsSync(runnerReturnCodeFilePath)) {
const runnerReturnCode: string = await executor.getRemoteFileContent(runnerReturnCodeFilePath);
const match: RegExpMatchArray | null = runnerReturnCode.trim()
.match(/^-?(\d+)\s+(\d+)$/);
if (match !== null) {
const { 1: code } = match;
// Update trial job's status based on result code
if (parseInt(code, 10) === 0) {
environment.setStatus('SUCCEEDED');
} else {
environment.setStatus('FAILED');
}
this.releaseEnvironmentResource(environment);
}
const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`;
/* eslint-disable require-atomic-updates */
try {
// check if pid file exist
const pidExist = await executor.fileExist(jobpidPath);
if (!pidExist) {
return;
}
const isAlive = await executor.isProcessAlive(jobpidPath);
environment.status = 'RUNNING';
// if the process of jobpid is not alive any more
if (!isAlive) {
const remoteEnvironment: RemoteMachineEnvironmentInformation = environment as RemoteMachineEnvironmentInformation;
if (remoteEnvironment.rmMachineMeta === undefined) {
throw new Error(`${remoteEnvironment.id} machine meta not initialized!`);
}
this.log.info(`pid in ${remoteEnvironment.rmMachineMeta.ip}:${jobpidPath} is not alive!`);
if (fs.existsSync(runnerReturnCodeFilePath)) {
const runnerReturnCode: string = await executor.getRemoteFileContent(runnerReturnCodeFilePath);
const match: RegExpMatchArray | null = runnerReturnCode.trim()
.match(/^-?(\d+)\s+(\d+)$/);
if (match !== null) {
const { 1: code } = match;
// Update trial job's status based on result code
if (parseInt(code, 10) === 0) {
environment.setStatus('SUCCEEDED');
} else {
environment.setStatus('FAILED');
}
this.releaseEnvironmentResource(environment);
}
} catch (error) {
this.releaseEnvironmentResource(environment);
this.log.error(`Update job status exception, error is ${error.message}`);
}
}
} catch (error) {
this.log.error(`Update job status exception, error is ${error.message}`);
}
}
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
......@@ -245,6 +248,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
'envs', environment.id)
environment.command = `cd ${environment.runnerWorkingFolder} && \
${environment.command} --job_pid_file ${environment.runnerWorkingFolder}/pid \
1>${environment.runnerWorkingFolder}/trialrunner_stdout 2>${environment.runnerWorkingFolder}/trialrunner_stderr \
&& echo $? \`date +%s%3N\` >${environment.runnerWorkingFolder}/code`;
return Promise.resolve(true);
}
......@@ -266,7 +270,6 @@ ${environment.command} --job_pid_file ${environment.runnerWorkingFolder}/pid \
// Execute command in remote machine
executor.executeScript(executor.joinPath(environment.runnerWorkingFolder,
executor.getScriptName("run")), true, false);
environment.status = 'RUNNING';
if (environment.rmMachineMeta === undefined) {
throw new Error(`${environment.id} rmMachineMeta not initialized!`);
}
......
......@@ -662,19 +662,22 @@ class TrialDispatcher implements TrainingService {
trial.status = "RUNNING";
await this.commandChannel.sendCommand(trial.environment, NEW_TRIAL_JOB, trial.settings);
}
/**
* release the trial assigned environment resources
* @param trial
*/
private releaseEnvironment(trial: TrialDetail): void {
if (undefined === trial.environment) {
throw new Error(`TrialDispatcher: environment is not assigned to trial ${trial.id}, and cannot be released!`);
}
if (trial.environment.runningTrialCount <= 0) {
throw new Error(`TrialDispatcher: environment ${trial.environment.id} has no counted running trial!`);
if (trial.environment !== undefined) {
if (trial.environment.runningTrialCount <= 0) {
throw new Error(`TrialDispatcher: environment ${trial.environment.id} has no counted running trial!`);
}
trial.environment.runningTrialCount--;
trial.environment = undefined;
}
if (true === this.enableGpuScheduler) {
this.gpuScheduler.removeGpuReservation(trial);
}
trial.environment.runningTrialCount--;
trial.environment = undefined;
}
private async handleMetricData(trialId: string, data: any): Promise<void> {
......
......@@ -1089,10 +1089,10 @@ cli-width@^2.0.0:
version "2.2.0"
resolved "https://registry.yarnpkg.com/cli-width/-/cli-width-2.2.0.tgz#ff19ede8a9a5e579324147b0c11f0fbcbabed639"
cliui@^7.0.0:
version "7.0.1"
resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.1.tgz#a4cb67aad45cd83d8d05128fc9f4d8fbb887e6b3"
integrity sha512-rcvHOWyGyid6I1WjT/3NatKj2kDt9OdSHSXpyLXaMWFbKpGACNW8pRhhdPUq9MWUOdwn8Rz9AVETjF4105rZZQ==
cliui@^7.0.2:
version "7.0.3"
resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.3.tgz#ef180f26c8d9bff3927ee52428bfec2090427981"
integrity sha512-Gj3QHTkVMPKqwP3f7B4KPkBZRMR9r4rfi5bXFpg1a+Svvj8l7q5CnkBkVQzfxT5DFSsGk2+PascOgL0JYkL2kw==
dependencies:
string-width "^4.2.0"
strip-ansi "^6.0.0"
......@@ -1329,7 +1329,7 @@ debug@^3.1.0:
dependencies:
ms "^2.1.1"
debuglog@^1.0.1:
debuglog@*, debuglog@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
......@@ -1599,10 +1599,10 @@ es6-promisify@^5.0.0:
dependencies:
es6-promise "^4.0.3"
escalade@^3.0.2:
version "3.1.0"
resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.0.tgz#e8e2d7c7a8b76f6ee64c2181d6b8151441602d4e"
integrity sha512-mAk+hPSO8fLDkhV7V0dXazH5pDc6MrjBTPyD3VeKzxnVFjH1MIxbCdqGZB9O8+EwWakZs3ZCbDS4IpRt79V1ig==
escalade@^3.1.1:
version "3.1.1"
resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40"
integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==
escape-html@~1.0.3:
version "1.0.3"
......@@ -2375,7 +2375,7 @@ import-lazy@^2.1.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
imurmurhash@^0.1.4:
imurmurhash@*, imurmurhash@^0.1.4:
version "0.1.4"
resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
integrity sha1-khi5srkoojixPcT7a21XbyMUU+o=
......@@ -3057,6 +3057,11 @@ lockfile@^1.0.4:
dependencies:
signal-exit "^3.0.2"
lodash._baseindexof@*:
version "3.1.0"
resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
integrity sha1-/lK1OhxnYeQmGNZU5KJXie1hgiw=
lodash._baseuniq@~4.6.0:
version "4.6.0"
resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
......@@ -3064,10 +3069,32 @@ lodash._baseuniq@~4.6.0:
lodash._createset "~4.0.0"
lodash._root "~3.0.0"
lodash._bindcallback@*:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
integrity sha1-5THCdkTPi1epnhftlbNcdIeJOS4=
lodash._cacheindexof@*:
version "3.0.2"
resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
integrity sha1-PcaayCSY0u5ePOVgkbr9Ktx73pI=
lodash._createcache@*:
version "3.1.2"
resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
integrity sha1-VtagZAF2JeeevKa4AY4XRAvc8JM=
dependencies:
lodash._getnative "^3.0.0"
lodash._createset@~4.0.0:
version "4.0.3"
resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
lodash._getnative@*, lodash._getnative@^3.0.0:
version "3.9.1"
resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
integrity sha1-VwvH3t5G1hzc3mh9ZdPuy6o6r/U=
lodash._root@~3.0.0:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
......@@ -3116,6 +3143,11 @@ lodash.pick@^4.4.0:
version "4.4.0"
resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
lodash.restparam@*:
version "3.6.1"
resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
integrity sha1-k2pOMJ7zMKdkXtQUWYbIWuWyCAU=
lodash.unescape@4.0.1:
version "4.0.1"
resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
......@@ -3679,10 +3711,11 @@ npm-run-path@^2.0.0:
path-key "^2.0.0"
npm-user-validate@~1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/npm-user-validate/-/npm-user-validate-1.0.0.tgz#8ceca0f5cea04d4e93519ef72d0557a75122e951"
version "1.0.1"
resolved "https://registry.yarnpkg.com/npm-user-validate/-/npm-user-validate-1.0.1.tgz#31428fc5475fe8416023f178c0ab47935ad8c561"
integrity sha512-uQwcd/tY+h1jnEaze6cdX/LrhWhoBxfSknxentoqmIuStxUExxjWd3ULMLFPiFUrZKbOVMowH6Jq2FRWfmhcEw==
npm@5.1.0, npm@^6.14.8:
npm@5.1.0, npm@>=6.14.8:
version "6.14.8"
resolved "https://registry.yarnpkg.com/npm/-/npm-6.14.8.tgz#64ef754345639bc035982ec3f609353c8539033c"
integrity sha512-HBZVBMYs5blsj94GTeQZel7s9odVuuSUHy1+AlZh7rPVux1os2ashvEGLy/STNK7vUjbrCg5Kq9/GXisJgdf6A==
......@@ -3705,6 +3738,7 @@ npm@5.1.0, npm@^6.14.8:
cmd-shim "^3.0.3"
columnify "~1.5.4"
config-chain "^1.1.12"
debuglog "*"
detect-indent "~5.0.0"
detect-newline "^2.1.0"
dezalgo "~1.0.3"
......@@ -3719,6 +3753,7 @@ npm@5.1.0, npm@^6.14.8:
has-unicode "~2.0.1"
hosted-git-info "^2.8.8"
iferr "^1.0.2"
imurmurhash "*"
infer-owner "^1.0.4"
inflight "~1.0.6"
inherits "^2.0.4"
......@@ -3737,8 +3772,14 @@ npm@5.1.0, npm@^6.14.8:
libnpx "^10.2.4"
lock-verify "^2.1.0"
lockfile "^1.0.4"
lodash._baseindexof "*"
lodash._baseuniq "~4.6.0"
lodash._bindcallback "*"
lodash._cacheindexof "*"
lodash._createcache "*"
lodash._getnative "*"
lodash.clonedeep "~4.5.0"
lodash.restparam "*"
lodash.union "~4.6.0"
lodash.uniq "~4.5.0"
lodash.without "~4.4.0"
......@@ -5664,10 +5705,10 @@ y18n@^4.0.0:
resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.0.tgz#95ef94f85ecc81d007c264e190a120f0a3c8566b"
integrity sha512-r9S/ZyXu/Xu9q1tYlpsLIsa3EeLXXk0VwlxqTcFRfg9EhMW+17kbt9G0NrgCmhGb5vT2hyhJZLfDGx+7+5Uj/w==
y18n@^5.0.1:
version "5.0.1"
resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.1.tgz#1ad2a7eddfa8bce7caa2e1f6b5da96c39d99d571"
integrity sha512-/jJ831jEs4vGDbYPQp4yGKDYPSCCEQ45uZWJHE1AoYBzqdZi8+LDWas0z4HrmJXmKdpFsTiowSHXdxyFhpmdMg==
y18n@^5.0.2:
version "5.0.4"
resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.4.tgz#0ab2db89dd5873b5ec4682d8e703e833373ea897"
integrity sha512-deLOfD+RvFgrpAmSZgfGdWYE+OKyHcVHaRQ7NphG/63scpRvTHHeQMAxGGvaLVGJ+HYVcCXlzcTK0ZehFf+eHQ==
yallist@^2.1.2:
version "2.1.2"
......@@ -5686,10 +5727,10 @@ yallist@^4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72"
yargs-parser@13.1.2, yargs-parser@^20.0.0, yargs-parser@^20.2.0:
version "20.2.0"
resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.0.tgz#944791ca2be2e08ddadd3d87e9de4c6484338605"
integrity sha512-2agPoRFPoIcFzOIp6656gcvsg2ohtscpw2OINr/q46+Sq41xz2OYLqx5HRHabmFU1OARIPAYH5uteICE7mn/5A==
yargs-parser@13.1.2, yargs-parser@>=20.2.0, yargs-parser@^20.2.2:
version "20.2.3"
resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.3.tgz#92419ba867b858c868acf8bae9bf74af0dd0ce26"
integrity sha512-emOFRT9WVHw03QSvN5qor9QQT9+sw5vwxfYweivSMHTcAXPefwVae2FjO7JJjj8hCE4CzPOPeFM83VwT29HCww==
yargs-unparser@1.6.1:
version "1.6.1"
......@@ -5702,18 +5743,18 @@ yargs-unparser@1.6.1:
is-plain-obj "^1.1.0"
yargs "^14.2.3"
yargs@13.3.2, yargs@^11.0.0, yargs@^14.2.3, yargs@^15.0.2, yargs@^16.0.3, yargs@^8.0.2:
version "16.0.3"
resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.0.3.tgz#7a919b9e43c90f80d4a142a89795e85399a7e54c"
integrity sha512-6+nLw8xa9uK1BOEOykaiYAJVh6/CjxWXK/q9b5FpRgNslt8s22F2xMBqVIKgCRjNgGvGPBy8Vog7WN7yh4amtA==
yargs@13.3.2, yargs@>=16.0.3, yargs@^11.0.0, yargs@^14.2.3, yargs@^15.0.2, yargs@^8.0.2:
version "16.1.0"
resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.1.0.tgz#fc333fe4791660eace5a894b39d42f851cd48f2a"
integrity sha512-upWFJOmDdHN0syLuESuvXDmrRcWd1QafJolHskzaw79uZa7/x53gxQKiR07W59GWY1tFhhU/Th9DrtSfpS782g==
dependencies:
cliui "^7.0.0"
escalade "^3.0.2"
cliui "^7.0.2"
escalade "^3.1.1"
get-caller-file "^2.0.5"
require-directory "^2.1.1"
string-width "^4.2.0"
y18n "^5.0.1"
yargs-parser "^20.0.0"
y18n "^5.0.2"
yargs-parser "^20.2.2"
yn@^2.0.0:
version "2.0.0"
......
......@@ -17,6 +17,11 @@ log_level_map = {
_time_format = '%m/%d/%Y, %I:%M:%S %p'
# FIXME
# This hotfix the bug that querying installed tuners with `package_utils` will activate dispatcher logger.
# This behavior depends on underlying implementation of `nnictl` and is likely to break in future.
_logger_initialized = False
class _LoggerFileWrapper(TextIOBase):
def __init__(self, logger_file):
self.file = logger_file
......@@ -33,6 +38,11 @@ def init_logger(logger_file_path, log_level_name='info'):
This will redirect anything from logging.getLogger() as well as stdout to specified file.
logger_file_path: path of logger file (path-like object).
"""
global _logger_initialized
if _logger_initialized:
return
_logger_initialized = True
log_level = log_level_map.get(log_level_name, logging.INFO)
logger_file = open(logger_file_path, 'w')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
......@@ -55,6 +65,11 @@ def init_standalone_logger():
Initialize root logger for standalone mode.
This will set NNI's log level to INFO and print its log to stdout.
"""
global _logger_initialized
if _logger_initialized:
return
_logger_initialized = True
fmt = '[%(asctime)s] %(levelname)s (%(name)s) %(message)s'
formatter = logging.Formatter(fmt, _time_format)
handler = logging.StreamHandler(sys.stdout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment