"...resnet50_onnxruntime_migraphx.git" did not exist on "17fa8aed578eaa9daa87dcaf6223a319963fe19a"
Unverified commit bf8be1e7, authored by Yuge Zhang, committed by GitHub

Merge pull request #2837 from microsoft/v1.8

Merge v1.8 back to master
Parents: 320407b1, e06a9dda
@@ -491,7 +491,7 @@ class LocalTrainingService implements TrainingService {
         if (process.platform === 'win32') {
             script.push(`cd $env:NNI_CODE_DIR`);
             script.push(
-                `cmd.exe /c ${localTrialConfig.command} 2>"${path.join(workingDirectory, 'stderr')}"`,
+                `cmd.exe /c ${localTrialConfig.command} 2>&1 | Out-File "${path.join(workingDirectory, 'stderr')}" -encoding utf8`,
                 `$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`,
                 `$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`,
                 `Write $LASTEXITCODE " " $NOW_DATE | Out-File "${path.join(workingDirectory, '.nni', 'state')}" -NoNewline -encoding utf8`);
@@ -523,6 +523,8 @@ class LocalTrainingService implements TrainingService {
         const runScriptContent: string[] = [];
         if (process.platform !== 'win32') {
             runScriptContent.push('#!/bin/bash');
+        } else {
+            runScriptContent.push(`$env:PATH="${process.env.path}"`)
         }
         for (const variable of variables) {
             runScriptContent.push(setEnvironmentVariable(variable));
@@ -87,6 +87,21 @@ class RemoteMachineTrainingService implements TrainingService {
             this.log.info('ssh connection initialized!');
             // set sshConnectionPromises to [] to avoid log information duplicated
             this.sshConnectionPromises = [];
+            // initialize gpuScheduler
+            this.gpuScheduler = new GPUScheduler(this.machineExecutorManagerMap);
+            if (this.trialConfig === undefined) {
+                throw new Error("trial config not initialized!");
+            }
+            // Copy codeDir to remote machine
+            for (const [rmMeta, executorManager] of this.machineExecutorManagerMap.entries()) {
+                const executor: ShellExecutor = await executorManager.getExecutor(this.initExecutorId);
+                if (executor !== undefined) {
+                    this.machineCopyExpCodeDirPromiseMap.set(
+                        rmMeta,
+                        executor.copyDirectoryToRemote(this.trialConfig.codeDir, executor.getRemoteCodePath(getExperimentId()))
+                    );
+                }
+            }
         }
         while (!this.stopping) {
             while (this.jobQueue.length > 0) {
@@ -310,7 +325,6 @@ class RemoteMachineTrainingService implements TrainingService {
                 break;
             case TrialConfigMetadataKey.MACHINE_LIST:
                 await this.setupConnections(value);
-                this.gpuScheduler = new GPUScheduler(this.machineExecutorManagerMap);
                 break;
             case TrialConfigMetadataKey.TRIAL_CONFIG: {
                 const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value);
@@ -327,20 +341,8 @@ class RemoteMachineTrainingService implements TrainingService {
                 try {
                     // Validate to make sure codeDir doesn't have too many files
                     await validateCodeDir(remoteMachineTrailConfig.codeDir);
-                    // Copy codeDir to remote machine
-                    for (const [rmMeta, executorManager] of this.machineExecutorManagerMap.entries()) {
-                        const executor: ShellExecutor = await executorManager.getExecutor(this.initExecutorId);
-                        if (executor !== undefined) {
-                            this.machineCopyExpCodeDirPromiseMap.set(
-                                rmMeta,
-                                executor.copyDirectoryToRemote(remoteMachineTrailConfig.codeDir, executor.getRemoteCodePath(getExperimentId()))
-                            );
-                        }
-                    }
                 } catch (error) {
                     this.log.error(error);
                     return Promise.reject(new Error(error));
                 }
@@ -426,19 +428,19 @@ class RemoteMachineTrainingService implements TrainingService {
         const rmMetaList: RemoteMachineMeta[] = <RemoteMachineMeta[]>JSON.parse(machineList);
         for (const rmMeta of rmMetaList) {
-            rmMeta.occupiedGpuIndexMap = new Map<number, number>();
-            const executorManager: ExecutorManager = new ExecutorManager(rmMeta);
-            this.log.info(`connecting to ${rmMeta.username}@${rmMeta.ip}:${rmMeta.port}`);
-            const executor: ShellExecutor = await executorManager.getExecutor(this.initExecutorId);
-            this.log.debug(`reached ${executor.name}`);
-            this.machineExecutorManagerMap.set(rmMeta, executorManager);
-            this.log.debug(`initializing ${executor.name}`);
-            this.sshConnectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor));
-            this.log.info(`connecting to ${executor.name}`);
+            this.sshConnectionPromises.push(this.initRemoteMachineOnConnected(rmMeta));
         }
     }

-    private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise<void> {
+    private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta): Promise<void> {
+        rmMeta.occupiedGpuIndexMap = new Map<number, number>();
+        const executorManager: ExecutorManager = new ExecutorManager(rmMeta);
+        this.log.info(`connecting to ${rmMeta.username}@${rmMeta.ip}:${rmMeta.port}`);
+        const executor: ShellExecutor = await executorManager.getExecutor(this.initExecutorId);
+        this.log.debug(`reached ${executor.name}`);
+        this.machineExecutorManagerMap.set(rmMeta, executorManager);
+        this.log.debug(`initializing ${executor.name}`);
         // Create root working directory after executor is ready
         const nniRootDir: string = executor.joinPath(executor.getTempPath(), 'nni');
         await executor.createFolder(executor.getRemoteExperimentRootDir(getExperimentId()));
@@ -74,13 +74,11 @@ export class AMLClient {
             throw Error('python shell client not initialized!');
         }
         this.pythonShellClient.send('tracking_url');
-        let trackingUrl = '';
-        this.pythonShellClient.on('message', function (status: any) {
-            const items = status.split(':');
-            if (items[0] === 'tracking_url') {
-                trackingUrl = items.splice(1, items.length).join('')
-            }
-            deferred.resolve(trackingUrl);
+        this.pythonShellClient.on('message', (status: any) => {
+            const trackingUrl = this.parseContent('tracking_url', status);
+            if (trackingUrl !== '') {
+                deferred.resolve(trackingUrl);
+            }
         });
         this.monitorError(this.pythonShellClient, deferred);
         return deferred.promise;
@@ -91,12 +89,11 @@ export class AMLClient {
         if (this.pythonShellClient === undefined) {
             throw Error('python shell client not initialized!');
         }
-        let newStatus = oldStatus;
         this.pythonShellClient.send('update_status');
-        this.pythonShellClient.on('message', function (status: any) {
-            const items = status.split(':');
-            if (items[0] === 'status') {
-                newStatus = items.splice(1, items.length).join('')
+        this.pythonShellClient.on('message', (status: any) => {
+            let newStatus = this.parseContent('status', status);
+            if (newStatus === '') {
+                newStatus = oldStatus;
             }
             deferred.resolve(newStatus);
         });
@@ -117,10 +114,10 @@ export class AMLClient {
             throw Error('python shell client not initialized!');
         }
         this.pythonShellClient.send('receive');
-        this.pythonShellClient.on('message', function (command: any) {
-            const items = command.split(':')
-            if (items[0] === 'receive') {
-                deferred.resolve(JSON.parse(command.slice(8)))
+        this.pythonShellClient.on('message', (command: any) => {
+            const message = this.parseContent('receive', command);
+            if (message !== '') {
+                deferred.resolve(JSON.parse(message))
             }
         });
         this.monitorError(this.pythonShellClient, deferred);
@@ -136,4 +133,13 @@ export class AMLClient {
             deferred.reject(error);
         });
     }
+
+    // Parse command content, command format is {head}:{content}
+    public parseContent(head: string, command: string): string {
+        const items = command.split(':');
+        if (items[0] === head) {
+            return command.slice(head.length + 1);
+        }
+        return '';
+    }
 }
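The messages exchanged with the AML Python shell follow a `{head}:{content}` convention, and `parseContent` slices the raw string rather than re-joining the `split(':')` pieces, so content that itself contains colons (a tracking URL, for instance) survives intact. A minimal standalone Python sketch of the same parse, using hypothetical message values:

```python
def parse_content(head: str, command: str) -> str:
    """Return the content after '{head}:', or '' when the head does not match."""
    items = command.split(':')
    if items[0] == head:
        return command[len(head) + 1:]
    return ''

# The colons inside the URL are preserved because we slice instead of joining.
assert parse_content('tracking_url', 'tracking_url:https://example.com/run/1') == 'https://example.com/run/1'
assert parse_content('status', 'tracking_url:https://example.com/run/1') == ''
```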
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import * as chai from 'chai';
import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils';
import chaiAsPromised = require("chai-as-promised");
import { AMLClient } from '../aml/amlClient';

describe('Unit Test for amlClient', () => {
    before(() => {
        chai.should();
        chai.use(chaiAsPromised);
        prepareUnitTest();
    });

    after(() => {
        cleanupUnitTest();
    });

    it('test parseContent', async () => {
        let amlClient: AMLClient = new AMLClient('', '', '', '', '', '', '', '');
        chai.assert.equal(amlClient.parseContent('test', 'test:1234'), '1234', "The content should be 1234");
        chai.assert.equal(amlClient.parseContent('test', 'abcd:1234'), '', "The content should be empty");
    });
});
@@ -6,7 +6,10 @@ Abstract base classes for TensorFlow model compression.
 """
 import logging

 import tensorflow as tf
+assert tf.__version__.startswith('2'), 'NNI model compression only supports TensorFlow v2.x'

 from . import default_layers

 _logger = logging.getLogger(__name__)
@@ -25,9 +28,9 @@ class LayerInfo:
         The layer's name. Note that it's local to sub-model and may differ from its attribute name.
     type : str
         Name of the layer's class.
-    path : list of str/int
+    path : list of str or tuple of (str, int)
         The layer object's and its parents' attribute name / list index.
-        For example, if the path is `['cells', 2, 'conv']`, then the layer can be accessed as `model.cells[2].conv`.
+        For example, if the path is `[('cells', 2), 'conv']`, then the layer can be accessed as `model.cells[2].conv`.
     config : JSON object
         Selected configuration for this layer. The format is detailed in tutorial.

@@ -35,7 +38,7 @@ class LayerInfo:
     ----------
     layer : tf.keras.layers.Layer
         See attributes section.
-    path : list of str/int
+    path : list of str or tuple of (str, int)
         See attributes section.
     """
@@ -75,6 +78,8 @@ class Compressor:
     def __init__(self, LayerWrapperClass, model, config_list):
         assert isinstance(model, tf.keras.Model)
+        if isinstance(model, tf.keras.Sequential):
+            raise ValueError('NNI model compression does not support `Sequential` model for now')
         self.validate_config(model, config_list)
         self.bound_model = model
@@ -204,10 +209,12 @@ class PrunerLayerWrapper(tf.keras.Model):
         for weight in self.layer.weights:
             mask = self.masks.get(weight.name)
             if mask is not None:
-                new_weights.append(tf.math.multiply(weight, mask).numpy())
+                new_weights.append(tf.math.multiply(weight, mask))
             else:
-                new_weights.append(weight.numpy())
-        self.layer.set_weights(new_weights)
+                new_weights.append(weight)
+        if new_weights and not hasattr(new_weights[0], 'numpy'):
+            raise RuntimeError('NNI: Compressed model can only run in eager mode')
+        self.layer.set_weights([weight.numpy() for weight in new_weights])
         return self.layer(*inputs)
@@ -244,26 +251,21 @@ def _locate_layers(model, cur_path=[]):
     # and to my knowledge `Layer.name` is only useful for read-only access.
     # `cur_path`s format is documented in `LayerInfo.path`.
     # TODO: it can only find layers in `Model` and `list` for now.
+    assert isinstance(model, tf.keras.Model)
+    if isinstance(model, tf.keras.Sequential):
+        _logger.warning('`Sequential` model is not supported yet, ignored.')
     ret = {}
-    if isinstance(model, tf.keras.Model):
-        for key, value in model.__dict__.items():
-            if isinstance(value, tf.keras.Model):
-                ret.update(_locate_layers(value, cur_path + [key]))
-            elif isinstance(value, list):
-                ret.update(_locate_layers(value, cur_path + [key]))
-            elif isinstance(value, tf.keras.layers.Layer):
-                ret[id(value)] = LayerInfo(value, cur_path + [key])
-    elif isinstance(model, list):
-        for i, item in enumerate(model):
-            if isinstance(item, tf.keras.Model):
-                ret.update(_locate_layers(item, cur_path + [i]))
-            elif isinstance(item, tf.keras.layers.Layer):
-                ret[id(item)] = LayerInfo(item, cur_path + [i])
-    else:
-        raise ValueError('Unexpected model type: {}'.format(type(model)))
+    for key, value in model.__dict__.items():
+        if isinstance(value, tf.keras.Model):
+            ret.update(_locate_layers(value, cur_path + [key]))
+        elif isinstance(value, tf.keras.layers.Layer):
+            ret[id(value)] = LayerInfo(value, cur_path + [key])
+        elif isinstance(value, list):
+            for i, item in enumerate(value):
+                if isinstance(item, tf.keras.Model):
+                    ret.update(_locate_layers(item, cur_path + [(key, i)]))
+                elif isinstance(item, tf.keras.layers.Layer):
+                    ret[id(item)] = LayerInfo(item, cur_path + [(key, i)])
     return ret

 def _select_config(layer_info, config_list):
@@ -289,12 +291,17 @@ def _instrument_model(model, wrappers):
     for wrapper in reversed(wrappers):
         cur = model
         for key in wrapper.layer_info.path[:-1]:
-            if isinstance(key, int):
-                cur = cur[key]
-            else:
+            if isinstance(key, str):
                 cur = getattr(cur, key)
+            else:
+                name, index = key
+                cur = getattr(cur, name)[index]
         key = wrapper.layer_info.path[-1]
-        if isinstance(key, int):
-            cur[key] = wrapper
-        else:
+        if isinstance(key, str):
             setattr(cur, key, wrapper)
+        else:
+            name, index = key
+            getattr(cur, name)[index] = wrapper
+        #if isinstance(cur, tf.keras.Sequential):
+        #    cur._graph_initialized = False
+        #    cur._layer_call_argspecs[wrapper] = cur._layer_call_argspecs[wrapper.layer]
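To make the two-branch path walk above concrete, here is a self-contained sketch (with hypothetical `DummyModel`/`DummyCell` classes, not NNI code) of resolving a `LayerInfo.path` such as `[('cells', 2), 'conv']`:

```python
class DummyCell:
    def __init__(self):
        self.conv = 'conv-layer'

class DummyModel:
    def __init__(self):
        self.cells = [DummyCell(), DummyCell(), DummyCell()]

def resolve(model, path):
    # Mirrors _instrument_model's traversal: a str key is a plain attribute,
    # anything else is an (attribute, index) pair into a list attribute.
    cur = model
    for key in path:
        if isinstance(key, str):
            cur = getattr(cur, key)
        else:
            name, index = key
            cur = getattr(cur, name)[index]
    return cur

assert resolve(DummyModel(), [('cells', 2), 'conv']) == 'conv-layer'
```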
@@ -44,20 +44,24 @@ class LevelPrunerMasker(WeightMasker):
     def calc_masks(self, sparsity, wrapper, wrapper_idx=None):
         masks = {}
         for weight_variable in wrapper.layer.weights:
-            if weight_variable.name == 'bias':
+            if 'bias' in weight_variable.name:
                 continue
-            k = int(tf.size(weight_variable).numpy() * sparsity)
-            if k == 0:
+            num_prune = int(tf.size(weight_variable).numpy() * sparsity)
+            if num_prune == 0:
                 continue
             weight = weight_variable.read_value()
             if wrapper.masks.get(weight_variable.name) is not None:
                 weight = tf.math.multiply(weight, wrapper.masks[weight_variable.name])
-            w_abs = tf.math.abs(tf.reshape(weight, [-1]))
-            threshold = tf.math.top_k(w_abs, k)[0][0]
-            mask = tf.math.greater(w_abs, threshold)
+            w_abs = tf.math.abs(weight)
+            k = tf.size(weight) - num_prune
+            topk = tf.math.top_k(tf.reshape(w_abs, [-1]), k)[0]
+            if tf.size(topk) == 0:
+                mask = tf.zeros_like(weight)
+            else:
+                mask = tf.math.greater_equal(w_abs, topk[-1])
             masks[weight_variable.name] = tf.cast(mask, weight.dtype)
         return masks
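For intuition, a minimal NumPy sketch of the masking rule above (a sketch, not the NNI implementation): keep the `size - num_prune` largest-magnitude entries and zero the rest; the `tf.size(topk) == 0` branch in the real code covers the case where nothing is kept.

```python
import numpy as np

def level_mask(weight, sparsity):
    num_prune = int(weight.size * sparsity)
    k = weight.size - num_prune             # number of entries to keep
    if k == 0:
        return np.zeros_like(weight)        # everything pruned
    w_abs = np.abs(weight)
    threshold = np.sort(w_abs.ravel())[-k]  # magnitude of the k-th largest entry
    return (w_abs >= threshold).astype(weight.dtype)

w = np.array([0.05, -0.4, 0.3, -0.1, 0.2])
print(level_mask(w, sparsity=0.6))  # [0. 1. 1. 0. 0.] -- the two largest magnitudes survive
```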
@@ -17,7 +17,7 @@ from ..utils.sensitivity_analysis import SensitivityAnalysis
 MAX_PRUNE_RATIO_PER_ITER = 0.95

 _logger = logging.getLogger('Sensitivity_Pruner')
-_logger.setLevel(logging.INFO)

 class SensitivityPruner(Pruner):
     """
@@ -202,10 +202,10 @@ class SensitivityPruner(Pruner):
             prune_ratios = sorted(sensitivities[layer].keys())
             last_ratio = 0
             for ratio in prune_ratios:
-                last_ratio = ratio
                 cur_acc = sensitivities[layer][ratio]
                 if cur_acc + threshold < ori_acc:
                     break
+                last_ratio = ratio
             max_ratio[layer] = last_ratio
         return max_ratio
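The reordering matters: `last_ratio` should only record ratios that pass the accuracy check, otherwise the loop returns the first failing ratio. A quick check with hypothetical sensitivity numbers:

```python
sensitivities = {0.1: 0.89, 0.3: 0.88, 0.5: 0.80}  # prune ratio -> accuracy (hypothetical)
ori_acc, threshold = 0.90, 0.05

last_ratio = 0
for ratio in sorted(sensitivities):
    cur_acc = sensitivities[ratio]
    if cur_acc + threshold < ori_acc:
        break              # 0.80 + 0.05 < 0.90, so 0.5 is rejected
    last_ratio = ratio     # only ratios that passed the check are recorded

print(last_ratio)  # 0.3 -- the pre-fix ordering would have returned the failing 0.5
```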
@@ -244,6 +244,7 @@ class SensitivityPruner(Pruner):
             # MAX_PRUNE_RATIO_PER_ITER we rescale all prune
             # ratios under this threshold
             if _Max > MAX_PRUNE_RATIO_PER_ITER:
                 for layername in ratios:
                     ratios[layername] = ratios[layername] * \
                         MAX_PRUNE_RATIO_PER_ITER / _Max
@@ -317,6 +318,7 @@ class SensitivityPruner(Pruner):
             finetune_kwargs = {}
         if self.ori_acc is None:
             self.ori_acc = self.evaluator(*eval_args, **eval_kwargs)
+            assert isinstance(self.ori_acc, float) or isinstance(self.ori_acc, int)
         if not resume_sensitivity:
             self.sensitivities = self.analyzer.analysis(
                 val_args=eval_args, val_kwargs=eval_kwargs)
@@ -330,6 +332,7 @@ class SensitivityPruner(Pruner):
         iteration_count = 0
         if self.checkpoint_dir is not None:
             os.makedirs(self.checkpoint_dir, exist_ok=True)
+        modules_wrapper_final = None
         while cur_ratio > target_ratio:
             iteration_count += 1
             # Each round has three steps:
@@ -343,9 +346,16 @@ class SensitivityPruner(Pruner):
             # layers according to the sensitivity result
             proportion = self.sparsity_proportion_calc(
                 ori_acc, self.acc_drop_threshold, self.sensitivities)
             new_pruneratio = self.normalize(proportion, self.sparsity_per_iter)
             cfg_list = self.create_cfg(new_pruneratio)
+            if not cfg_list:
+                _logger.error('The threshold is too small, please set a larger threshold')
+                return self.model
             _logger.debug('Pruner Config: %s', str(cfg_list))
+            cfg_str = ['%s:%.3f'%(cfg['op_names'][0], cfg['sparsity']) for cfg in cfg_list]
+            _logger.info('Current Sparsities: %s', ','.join(cfg_str))
             pruner = self.Pruner(self.model, cfg_list)
             pruner.compress()
             pruned_acc = self.evaluator(*eval_args, **eval_kwargs)
@@ -367,6 +377,7 @@ class SensitivityPruner(Pruner):
                 self.analyzer.already_pruned[name] = sparsity
             # update the cur_ratio
             cur_ratio = 1 - self.current_sparsity()
+            modules_wrapper_final = pruner.get_modules_wrapper()
             del pruner
             _logger.info('Currently remained weights: %f', cur_ratio)
@@ -383,14 +394,19 @@ class SensitivityPruner(Pruner):
                 with open(cfg_path, 'w') as jf:
                     json.dump(cfg_list, jf)
                 self.analyzer.export(sensitivity_path)
             if cur_ratio > target_ratio:
                 # If this is the last prune iteration, skip the time-consuming
                 # sensitivity analysis
                 self.analyzer.load_state_dict(self.model.state_dict())
                 self.sensitivities = self.analyzer.analysis(
                     val_args=eval_args, val_kwargs=eval_kwargs)
         _logger.info('After Pruning: %.2f weights remains', cur_ratio)
+        self.modules_wrapper = modules_wrapper_final
+        self._wrap_model()
         return self.model

     def calc_mask(self, wrapper, **kwargs):
@@ -222,6 +222,10 @@ infer_from_inshape = {
     'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
+    'aten::tanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
+    'aten::tanh_': lambda module_masks, mask: relu_inshape(module_masks, mask),
+    'aten::hardtanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
+    'aten::hardtanh_': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
     'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
@@ -282,7 +286,7 @@ def cat_inshape(module_masks, mask, cat_info, last_visited):
     Parameters
     ----------
     module_masks : ModuleMasks
-        The ModuleMasks instance of the batchnorm2d
+        The ModuleMasks instance of the Conv2d
     mask : CoarseMask
         The mask of its input tensor
     cat_info: dict
@@ -118,11 +118,14 @@ class CatMaskPadding(MaskFix):
                 continue
             # pad the mask for the non-pruned layers
             for layer in layers:
+                if layer in self.masks:
+                    continue
                 module = name_to_module[layer]
                 w_shape = module.weight.data.size()
                 w_mask = torch.ones(w_shape).to(device)
                 b_mask = None
-                if hasattr(module, 'bias'):
+                if hasattr(module, 'bias') and module.bias is not None:
+                    # module.bias may be None
                     b_shape = module.bias.data.size()
                     b_mask = torch.ones(b_shape).to(device)
                 self.masks[layer] = {'weight':w_mask, 'bias':b_mask}
@@ -163,8 +163,8 @@ class SensitivityAnalysis:
         if val_kwargs is None:
             val_kwargs = {}
         # Get the original validation metric(accuracy/loss) before pruning
-        if self.ori_metric is None:
-            self.ori_metric = self.val_func(*val_args, **val_kwargs)
+        # Get the accuracy baseline before starting the analysis.
+        self.ori_metric = self.val_func(*val_args, **val_kwargs)
         namelist = list(self.target_layer.keys())
         if specified_layers is not None:
             # only analyze several specified conv layers
@@ -172,19 +172,21 @@ class SensitivityAnalysis:
         for name in namelist:
             self.sensitivities[name] = {}
             for sparsity in self.sparsities:
+                # here the sparsity is the relative sparsity of
+                # the remained weights
                 # Calculate the actual prune ratio based on the already pruned ratio
-                sparsity = (
+                real_sparsity = (
                     1.0 - self.already_pruned[name]) * sparsity + self.already_pruned[name]
                 # TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary
                 # I think the L1/L2 Pruner should specify the op_types automatically
                 # according to the op_names
-                cfg = [{'sparsity': sparsity, 'op_names': [
+                cfg = [{'sparsity': real_sparsity, 'op_names': [
                     name], 'op_types': ['Conv2d']}]
                 pruner = self.Pruner(self.model, cfg)
                 pruner.compress()
                 val_metric = self.val_func(*val_args, **val_kwargs)
                 logger.info('Layer: %s Sparsity: %.2f Validation Metric: %.4f',
-                            name, sparsity, val_metric)
+                            name, real_sparsity, val_metric)
                 self.sensitivities[name][sparsity] = val_metric
                 pruner._unwrap_model()
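The `real_sparsity` conversion turns a sparsity relative to the remaining weights into an absolute one. A quick check with hypothetical numbers:

```python
already_pruned = 0.5   # layer is already 50% pruned (hypothetical)
sparsity = 0.4         # prune 40% of what remains
real_sparsity = (1.0 - already_pruned) * sparsity + already_pruned
print(real_sparsity)   # 0.7 -- the absolute sparsity handed to the pruner
```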
@@ -15,7 +15,7 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
     arch : dict or None
         If a dict, it is in the format that is described in
         :class:`nni.nas.benchmark.nasbench101.Nb101TrialConfig`. Only trial stats
-        matched will be returned. If none, architecture will be a wildcard.
+        matched will be returned. If none, all architectures in the database will be matched.
     num_epochs : int or None
         If int, matching results will be returned. Otherwise a wildcard.
     isomorphism : boolean
@@ -14,7 +14,7 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
     arch : dict or None
         If a dict, it is in the format that is described in
         :class:`nni.nas.benchmark.nasbench201.Nb201TrialConfig`. Only trial stats
-        matched will be returned. If none, architecture will be a wildcard.
+        matched will be returned. If none, all architectures in the database will be matched.
     num_epochs : int or None
         If int, matching results will be returned. Otherwise a wildcard.
     dataset : str or None
@@ -162,7 +162,7 @@ class Mutator(BaseMutator):
         if self._connect_all:
             return self._all_connect_tensor_reduction(mutable.reduction,
                                                       [op(*args, **kwargs) for op in mutable]), \
-                torch.ones(len(mutable))
+                torch.ones(len(mutable)).bool()

         def _map_fn(op, args, kwargs):
             return op(*args, **kwargs)
@@ -192,7 +192,7 @@ class Mutator(BaseMutator):
         """
         if self._connect_all:
             return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \
-                torch.ones(mutable.n_candidates)
+                torch.ones(mutable.n_candidates).bool()
         mask = self._get_decision(mutable)
         assert len(mask) == mutable.n_candidates, \
             "Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module):
     """
     def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction):
         super().__init__()
-        print(in_channels_pp, in_channels_p, out_channels, reduction)
        self.reduction = reduction
        if self.reduction:
            self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
@@ -110,7 +109,7 @@ class ENASMicroLayer(nn.Module):
         pprev: torch.Tensor
             the output of the previous previous layer
         prev: torch.Tensor
-            the output of the previous previous layer
+            the output of the previous layer
         """
         if self.reduction:
             pprev, prev = self.reduce0(pprev), self.reduce1(prev)
@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope):
             PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
             PoolBranch('max', in_filters, out_filters, 3, 1, 1)
         ])
-        if prev_labels > 0:
+        if prev_labels:
             self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
         else:
             self.skipconnect = None
@@ -286,32 +286,9 @@ def create_customized_class_instance(class_params):
     return instance

-def get_python_dir(sitepackages_path):
-    if sys.platform == "win32":
-        return str(Path(sitepackages_path))
-    else:
-        return str(Path(sitepackages_path).parents[2])

 def get_nni_installation_parent_dir():
     ''' Find nni installation parent directory
     '''
-    def try_installation_path_sequentially(*sitepackages):
-        '''Try different installation path sequentially util nni is found.
-        Return None if nothing is found
-        '''
-        def _generate_installation_path(sitepackages_path):
-            python_dir = get_python_dir(sitepackages_path)
-            entry_file = os.path.join(python_dir, 'nni', 'main.js')
-            if os.path.isfile(entry_file):
-                return python_dir
-            return None
-
-        for sitepackage in sitepackages:
-            python_dir = _generate_installation_path(sitepackage)
-            if python_dir:
-                return python_dir
-        return None

     if os.getenv('VIRTUAL_ENV'):
         # if 'virtualenv' package is used, `site` has not attr getsitepackages, so we will instead use VIRTUAL_ENV
         # Note that conda venv will not have VIRTUAL_ENV
@@ -321,12 +298,23 @@ def get_nni_installation_parent_dir():
     # If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given
     # that nni exists there
     if python_sitepackage.startswith('/usr') or python_sitepackage.startswith('/Library'):
-        python_dir = try_installation_path_sequentially(site.getusersitepackages(), site.getsitepackages()[0])
+        python_dir = _try_installation_path_sequentially(site.getusersitepackages(), *site.getsitepackages())
     else:
-        python_dir = try_installation_path_sequentially(site.getsitepackages()[0], site.getusersitepackages())
+        python_dir = _try_installation_path_sequentially(*site.getsitepackages(), site.getusersitepackages())
     return python_dir
+
+def _try_installation_path_sequentially(*sitepackages):
+    '''Try different installation paths sequentially until nni is found.
+    Return None if nothing is found
+    '''
+    for sitepackage in sitepackages:
+        path = Path(sitepackage)
+        if len(path.parents) > 2 and (path.parents[2] / 'nni' / 'main.js').is_file():
+            return str(path.parents[2])
+        if (path / 'nni' / 'main.js').is_file():
+            return str(path)
+    return None

 def get_nni_installation_path():
     ''' Find nni installation directory
     '''
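For each site-packages entry, the new helper probes two candidate layouts: three levels up (the directory the removed `get_python_dir` used to compute) and the site-packages directory itself. A sketch with a hypothetical path:

```python
from pathlib import Path

path = Path('/usr/lib/python3.8/site-packages')  # hypothetical site-packages entry

# Probe 1: <python root>/nni/main.js, i.e. three levels up (what get_python_dir returned).
print(path.parents[2] / 'nni' / 'main.js')  # /usr/nni/main.js
# Probe 2: <site-packages>/nni/main.js, where the nni package itself is installed.
print(path / 'nni' / 'main.js')             # /usr/lib/python3.8/site-packages/nni/main.js
```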
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import numpy as np
import tensorflow as tf

####
#
# This file tests pruners on 2 models: a classic CNN model, and a naive model with one linear layer.
#
# The CNN model is used to test layer detecting and instrumenting.
#
# The naive model is used to test mask calculation.
# It has a single 10x10 linear layer without bias, and `reduce_sum`s its result.
# To help predict the pruning result, the linear layer has fixed initial weights:
#   [ [ 0.0, 1.0, 2.0, ..., 9.0 ], [ 0.1, 1.1, 2.1, ..., 9.1 ], ..., [ 0.9, 1.9, 2.9, ..., 9.9 ] ]
#
####

# This tensor is used as input of the 10x10 linear layer; the first dimension is batch size
tensor1x10 = tf.constant([[1.0] * 10])


@unittest.skipIf(tf.__version__[0] != '2', 'Skip TF 1.x setup')
class TfCompressorTestCase(unittest.TestCase):
    def test_layer_detection(self):
        # Conv and dense layers should be compressed, pool and flatten should not.
        # This also tests instrumenting functionality.
        self._test_layer_detection_on_model(CnnModel())

    def _test_layer_detection_on_model(self, model):
        pruner = pruners['level'](model)
        pruner.compress()
        layer_types = sorted(wrapper.layer_info.type for wrapper in pruner.wrappers)
        assert layer_types == ['Conv2D', 'Dense', 'Dense'], layer_types

    def test_level_pruner(self):
        # prune 90% : 9.0 + 9.1 + ... + 9.9 = 94.5
        model = build_naive_model()
        pruners['level'](model).compress()
        x = model(tensor1x10)
        assert x.numpy() == 94.5


try:
    from tensorflow.keras import Model, Sequential
    from tensorflow.keras.layers import (Conv2D, Dense, Flatten, MaxPool2D)
    from nni.compression.tensorflow import LevelPruner

    pruners = {
        'level': (lambda model: LevelPruner(model, [{'sparsity': 0.9, 'op_types': ['default']}])),
    }

    class CnnModel(Model):
        def __init__(self):
            super().__init__()
            self.conv = Conv2D(filters=10, kernel_size=3, activation='relu')
            self.pool = MaxPool2D(pool_size=2)
            self.flatten = Flatten()
            self.fc1 = Dense(units=10, activation='relu')
            self.fc2 = Dense(units=5, activation='softmax')

        def call(self, x):
            x = self.conv(x)
            x = self.pool(x)
            x = self.flatten(x)
            x = self.fc1(x)
            x = self.fc2(x)
            return x

    class NaiveModel(Model):
        def __init__(self):
            super().__init__()
            self.fc = Dense(units=10, use_bias=False)

        def call(self, x):
            return tf.math.reduce_sum(self.fc(x))

except Exception:
    pass


def build_naive_model():
    model = NaiveModel()
    model.build(tensor1x10.shape)
    weight = [[(i + j * 0.1) for i in range(10)] for j in range(10)]
    model.set_weights([np.array(weight)])
    return model


if __name__ == '__main__':
    unittest.main()
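A quick sanity check of the `test_level_pruner` expectation: with 90% sparsity only the ten largest weights survive (the last column, 9.0 through 9.9), and the all-ones input simply sums them:

```python
print(sum(9.0 + j * 0.1 for j in range(10)))  # 94.5 (up to floating-point rounding)
```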
@@ -145,7 +145,7 @@ class SpeedupTestCase(TestCase):
         assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

     def test_speedup_integration(self):
-        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'inception_v3']:
+        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3']:
             Model = getattr(models, model_name)
             net = Model(pretrained=True, progress=False).to(device)
             speedup_model = Model().to(device)
@@ -85,8 +85,10 @@ class Compare extends React.Component<CompareProps, {}> {
                 containLabel: true
             },
             legend: {
-                // more than 10 trials will hide legend
-                data: idsList.length > 10 ? null : idsList
+                type: 'scroll',
+                right: 40,
+                left: idsList.length > 6 ? 80 : null,
+                data: idsList
             },
             xAxis: {
                 type: 'category',
@@ -135,8 +137,17 @@ class Compare extends React.Component<CompareProps, {}> {
             isComplexSearchSpace = (typeof parameterList[0][parameterKeys[0]] === 'object')
                 ? true : false;
         }
+        const width = this.getWebUIWidth();
+        let scrollClass;
+        if (width > 1200) {
+            scrollClass = idList.length > 3 ? 'flex' : '';
+        } else if (width < 700) {
+            scrollClass = idList.length > 1 ? 'flex' : '';
+        } else {
+            scrollClass = idList.length > 2 ? 'flex' : '';
+        }
         return (
-            <table className="compare-modal-table">
+            <table className={`compare-modal-table ${scrollClass}`}>
                 <tbody>
                     <tr>
                         <td className="column">Id</td>
@@ -200,6 +211,10 @@ class Compare extends React.Component<CompareProps, {}> {
         );
     }

+    getWebUIWidth = (): number => {
+        return window.innerWidth;
+    }
+
     componentDidMount(): void {
         this._isCompareMount = true;
     }
 import * as React from 'react';
 import { DOWNLOAD_IP } from '../../static/const';
+import LogPathChild from './LogPathChild';

 interface PaiTrialChildProps {
     logString: string;
@@ -21,7 +22,7 @@ class PaiTrialChild extends React.Component<PaiTrialChildProps, {}> {
                 {
                     logString === ''
                         ?
-                        <div />
+                        null
                         :
                         <div>
                             {
@@ -33,10 +34,13 @@ class PaiTrialChild extends React.Component<PaiTrialChildProps, {}> {
                                         href={`${DOWNLOAD_IP}/trial_${id}.log`}
                                         style={{ marginRight: 10 }}
                                     >
-                                        trial stdout
+                                        Trial stdout
                                     </a>
                                     :
-                                    <span>trial stdout: {logString}</span>
+                                    <LogPathChild
+                                        eachLogpath={logString}
+                                        logName="Trial stdout:"
+                                    />
                             }
                         </div>
                 }
@@ -42,7 +42,7 @@ class PaitrialLog extends React.Component<PaitrialLogProps, {}> {
                             >
                                 Trial stdout
                             </a>
-                            <a target="_blank" rel="noopener noreferrer" href={logStr.split(',')[1]}>hdfsLog</a>
+                            <a target="_blank" rel="noopener noreferrer" href={logStr.split(',')[1]}>NFS log</a>
                         </div>
                         :
                         <div>
@@ -52,7 +52,7 @@ class PaitrialLog extends React.Component<PaitrialLogProps, {}> {
                             />
                             <LogPathChild
                                 eachLogpath={logStr.split(',')[1]}
-                                logName="Log on HDFS:"
+                                logName="Log on NFS:"
                             />
                         </div>
                     }