Unverified Commit cd3a912a authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #218 from microsoft/master

merge master
parents a0846f2a e9cba778
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from unittest import TestCase, main
import numpy as np
import tensorflow as tf
import torch
import torch.nn.functional as F
import nni.compression.torch as torch_compressor
import math
if tf.__version__ >= '2.0':
import nni.compression.tensorflow as tf_compressor
def get_tf_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
......@@ -24,39 +29,66 @@ def get_tf_model():
metrics=["accuracy"])
return model
class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
self.bn1 = torch.nn.BatchNorm2d(5)
self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
self.bn2 = torch.nn.BatchNorm2d(10)
self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 10)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def tf2(func):
def test_tf2_func(*args):
if tf.__version__ >= '2.0':
func(*args)
return test_tf2_func
k1 = [[1]*3]*3
k2 = [[2]*3]*3
k3 = [[3]*3]*3
k4 = [[4]*3]*3
k5 = [[5]*3]*3
# for fpgm filter pruner test
w = np.array([[[[i+1]*3]*3]*5 for i in range(10)])
w = [[k1, k2, k3, k4, k5]] * 10
class CompressorTestCase(TestCase):
def test_torch_quantizer_modules_detection(self):
# test if modules can be detected
model = TorchModel()
config_list = [{
'quant_types': ['weight'],
'quant_bits': 8,
'op_types':['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 0,
'op_types':['ReLU']
}]
model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list)
quantizer.compress()
modules_to_compress = quantizer.get_modules_to_compress()
modules_to_compress_name = [ t[0].name for t in modules_to_compress]
assert "conv1" in modules_to_compress_name
assert "conv2" in modules_to_compress_name
assert "fc1" in modules_to_compress_name
assert "fc2" in modules_to_compress_name
assert "relu" in modules_to_compress_name
assert len(modules_to_compress_name) == 5
def test_torch_level_pruner(self):
model = TorchModel()
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
......@@ -74,7 +106,7 @@ class CompressorTestCase(TestCase):
'quant_bits': {
'weight': 8,
},
'op_types':['Conv2d', 'Linear']
'op_types': ['Conv2d', 'Linear']
}]
torch_compressor.NaiveQuantizer(model, configure_list).compress()
......@@ -84,16 +116,16 @@ class CompressorTestCase(TestCase):
def test_torch_fpgm_pruner(self):
"""
With filters(kernels) defined as above (k1 - k5), it is obvious that k3 is the Geometric Median
With filters(kernels) weights defined as above (w), it is obvious that w[4] and w[5] is the Geometric Median
which minimize the total geometric distance by defination of Geometric Median in this paper:
Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
https://arxiv.org/pdf/1811.00250.pdf
So if sparsity is 0.2, the expected masks should mask out all k3, this can be verified through:
`all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))`
So if sparsity is 0.2, the expected masks should mask out w[4] and w[5], this can be verified through:
`all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))`
If sparsity is 0.6, the expected masks should mask out all k2, k3, k4, this can be verified through:
`all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))`
If sparsity is 0.6, the expected masks should mask out w[2] - w[7], this can be verified through:
`all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))`
"""
model = TorchModel()
......@@ -103,12 +135,12 @@ class CompressorTestCase(TestCase):
model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0])
assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))
assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1)
model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1])
assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))
assert all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@tf2
def test_tf_fpgm_pruner(self):
......@@ -122,17 +154,122 @@ class CompressorTestCase(TestCase):
layer = tf_compressor.compressor.LayerInfo(model.layers[2])
masks = pruner.calc_mask(layer, config_list[0]).numpy()
masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3])
masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
assert all(masks.sum((0, 2, 3)) == np.array([90., 90., 0., 90., 90.]))
assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1)
model.layers[2].set_weights([weights[0], weights[1].numpy()])
masks = pruner.calc_mask(layer, config_list[1]).numpy()
masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3])
masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
assert all(masks.sum((1)) == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
def test_torch_l1filter_pruner(self):
"""
Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
PRUNING FILTERS FOR EFFICIENT CONVNETS,
https://arxiv.org/abs/1608.08710
So if sparsity is 0.2, the expected masks should mask out filter 0, this can be verified through:
`all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))`
If sparsity is 0.6, the expected masks should mask out filter 0,1,2, this can be verified through:
`all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))`
"""
w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}]
pruner = torch_compressor.L1FilterPruner(model, config_list)
model.conv1.weight.data = torch.tensor(w).float()
model.conv2.weight.data = torch.tensor(w).float()
layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
mask2 = pruner.calc_mask(layer2, config_list[1])
assert all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
def test_torch_slim_pruner(self):
"""
Scale factors with minimum l1 norm in the BN layers are pruned in this paper:
Learning Efficient Convolutional Networks through Network Slimming,
https://arxiv.org/pdf/1708.06519.pdf
So if sparsity is 0.2, the expected masks should mask out channel 0, this can be verified through:
`all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
`all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`
If sparsity is 0.6, the expected masks should mask out channel 0,1,2, this can be verified through:
`all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
`all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
"""
w = np.array([0, 1, 2, 3, 4])
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float()
model.bn2.weight.data = torch.tensor(-w).float()
pruner = torch_compressor.SlimPruner(model, config_list)
layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0])
assert all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))
config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float()
model.bn2.weight.data = torch.tensor(w).float()
pruner = torch_compressor.SlimPruner(model, config_list)
layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0])
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0])
assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))
def test_torch_QAT_quantizer(self):
model = TorchModel()
config_list = [{
'quant_types': ['weight'],
'quant_bits': 8,
'op_types':['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 0,
'op_types':['ReLU']
}]
model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list)
quantizer.compress()
# test quantize
# range not including 0
eps = 1e-7
weight = torch.tensor([[1, 2], [3, 5]]).float()
quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
assert math.isclose(model.conv2.scale, 5 / 255, abs_tol=eps)
assert model.conv2.zero_point == 0
# range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float()
quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
assert math.isclose(model.conv2.scale, 6 / 255, abs_tol=eps)
assert model.conv2.zero_point in (42, 43)
assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.]))
# test ema
x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
out = model.relu(x)
assert math.isclose(model.relu.tracked_min_biased, 0, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.002, abs_tol=eps)
quantizer.step()
x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
out = model.relu(x)
assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps)
if __name__ == '__main__':
main()
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from io import BytesIO
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import nni.protocol
from nni.protocol import CommandType, send, receive
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import nni
import nni.platform.test as test_platform
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from unittest import TestCase, main
......
......@@ -12,7 +12,7 @@
"copy-to-clipboard": "^3.0.8",
"css-loader": "0.28.7",
"dotenv": "^8.0.0",
"echarts": "^4.1.0",
"echarts": "^4.5.0",
"echarts-for-react": "^2.0.14",
"file-loader": "^4.1.0",
"fork-ts-checker-webpack-plugin": "^1.5.0",
......
......@@ -3,7 +3,7 @@ import { Switch } from 'antd';
import ReactEcharts from 'echarts-for-react';
import { EXPERIMENT, TRIALS } from '../../static/datamodel';
import { Trial } from '../../static/model/trial';
import { TooltipForAccuracy } from '../../static/interface';
import { TooltipForAccuracy, EventMap } from '../../static/interface';
require('echarts/lib/chart/scatter');
require('echarts/lib/component/tooltip');
require('echarts/lib/component/title');
......@@ -16,12 +16,18 @@ interface DefaultPointProps {
interface DefaultPointState {
bestCurveEnabled: boolean;
startY: number; // dataZoomY
endY: number;
}
class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> {
constructor(props: DefaultPointProps) {
super(props);
this.state = { bestCurveEnabled: false };
this.state = {
bestCurveEnabled: false,
startY: 0, // dataZoomY
endY: 100,
};
}
loadDefault = (checked: boolean) => {
......@@ -35,6 +41,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
render() {
const graph = this.generateGraph();
const accNodata = (graph === EmptyGraph ? 'No data' : '');
const onEvents = { 'dataZoom': this.metricDataZoom };
return (
<div>
......@@ -53,6 +60,7 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
}}
theme="my_theme"
notMerge={true} // update now
onEvents={onEvents}
/>
<div className="showMess">{accNodata}</div>
</div>
......@@ -64,31 +72,17 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
if (trials.length === 0) {
return EmptyGraph;
}
const graph = generateGraphConfig(trials[trials.length - 1].sequenceId);
const graph = this.generateGraphConfig(trials[trials.length - 1].sequenceId);
if (this.state.bestCurveEnabled) {
(graph as any).series = [ generateBestCurveSeries(trials), generateScatterSeries(trials) ];
(graph as any).series = [generateBestCurveSeries(trials), generateScatterSeries(trials)];
} else {
(graph as any).series = [ generateScatterSeries(trials) ];
(graph as any).series = [generateScatterSeries(trials)];
}
return graph;
}
}
const EmptyGraph = {
grid: {
left: '8%'
},
xAxis: {
name: 'Trial',
type: 'category',
},
yAxis: {
name: 'Default metric',
type: 'value',
}
};
function generateGraphConfig(maxSequenceId: number) {
private generateGraphConfig(maxSequenceId: number) {
const { startY, endY } = this.state;
return {
grid: {
left: '8%',
......@@ -97,7 +91,7 @@ function generateGraphConfig(maxSequenceId: number) {
trigger: 'item',
enterable: true,
position: (point: Array<number>, data: TooltipForAccuracy) => (
[ (data.data[0] < maxSequenceId ? point[0] : (point[0] - 300)), 80 ]
[(data.data[0] < maxSequenceId ? point[0] : (point[0] - 300)), 80]
),
formatter: (data: TooltipForAccuracy) => (
'<div class="tooldetailAccuracy">' +
......@@ -107,6 +101,16 @@ function generateGraphConfig(maxSequenceId: number) {
'</div>'
),
},
dataZoom: [
{
id: 'dataZoomY',
type: 'inside',
yAxisIndex: [0],
filterMode: 'empty',
start: startY,
end: endY
}
],
xAxis: {
name: 'Trial',
type: 'category',
......@@ -118,8 +122,33 @@ function generateGraphConfig(maxSequenceId: number) {
},
series: undefined,
};
}
private metricDataZoom = (e: EventMap) => {
if (e.batch !== undefined) {
this.setState(() => ({
startY: (e.batch[0].start !== null ? e.batch[0].start : 0),
endY: (e.batch[0].end !== null ? e.batch[0].end : 100)
}));
}
}
}
const EmptyGraph = {
grid: {
left: '8%'
},
xAxis: {
name: 'Trial',
type: 'category',
},
yAxis: {
name: 'Default metric',
type: 'value',
scale: true,
}
};
function generateScatterSeries(trials: Trial[]) {
const data = trials.map(trial => [
trial.sequenceId,
......@@ -135,17 +164,17 @@ function generateScatterSeries(trials: Trial[]) {
function generateBestCurveSeries(trials: Trial[]) {
let best = trials[0];
const data = [[ best.sequenceId, best.accuracy, best.description.parameters ]];
const data = [[best.sequenceId, best.accuracy, best.description.parameters]];
for (let i = 1; i < trials.length; i++) {
const trial = trials[i];
const delta = trial.accuracy! - best.accuracy!;
const better = (EXPERIMENT.optimizeMode === 'minimize') ? (delta < 0) : (delta > 0);
if (better) {
data.push([ trial.sequenceId, trial.accuracy, trial.description.parameters ]);
data.push([trial.sequenceId, trial.accuracy, trial.description.parameters]);
best = trial;
} else {
data.push([ trial.sequenceId, best.accuracy, trial.description.parameters ]);
data.push([trial.sequenceId, best.accuracy, trial.description.parameters]);
}
}
......
import * as React from 'react';
import ReactEcharts from 'echarts-for-react';
import { TableObj } from 'src/static/interface';
import { TableObj, EventMap } from 'src/static/interface';
import { filterDuration } from 'src/static/function';
require('echarts/lib/chart/bar');
require('echarts/lib/component/tooltip');
......@@ -17,7 +17,8 @@ interface DurationProps {
}
interface DurationState {
durationSource: {};
startDuration: number; // for record data zoom
endDuration: number;
}
class Duration extends React.Component<DurationProps, DurationState> {
......@@ -26,63 +27,13 @@ class Duration extends React.Component<DurationProps, DurationState> {
super(props);
this.state = {
durationSource: this.initDuration(this.props.source),
};
}
initDuration = (source: Array<TableObj>) => {
const trialId: Array<string> = [];
const trialTime: Array<number> = [];
const trialJobs = source.filter(filterDuration);
Object.keys(trialJobs).map(item => {
const temp = trialJobs[item];
trialId.push(temp.sequenceId);
trialTime.push(temp.duration);
});
return {
tooltip: {
trigger: 'axis',
axisPointer: {
type: 'shadow'
}
},
grid: {
bottom: '3%',
containLabel: true,
left: '1%',
right: '4%'
},
dataZoom: [{
type: 'slider',
name: 'trial',
filterMode: 'filter',
yAxisIndex: 0,
orient: 'vertical'
}, {
type: 'slider',
name: 'trial',
filterMode: 'filter',
xAxisIndex: 0
}],
xAxis: {
name: 'Time',
type: 'value',
},
yAxis: {
name: 'Trial',
type: 'category',
data: trialId
},
series: [{
type: 'bar',
data: trialTime
}]
startDuration: 0, // for record data zoom
endDuration: 100,
};
}
getOption = (dataObj: Runtrial) => {
const { startDuration, endDuration } = this.state;
return {
tooltip: {
trigger: 'axis',
......@@ -96,19 +47,16 @@ class Duration extends React.Component<DurationProps, DurationState> {
left: '1%',
right: '4%'
},
dataZoom: [{
type: 'slider',
name: 'trial',
filterMode: 'filter',
yAxisIndex: 0,
orient: 'vertical'
}, {
type: 'slider',
name: 'trial',
filterMode: 'filter',
xAxisIndex: 0
}],
dataZoom: [
{
id: 'dataZoomY',
type: 'inside',
yAxisIndex: [0],
filterMode: 'empty',
start: startDuration,
end: endDuration
},
],
xAxis: {
name: 'Time',
type: 'value',
......@@ -140,21 +88,7 @@ class Duration extends React.Component<DurationProps, DurationState> {
trialId: trialId,
trialTime: trialTime
});
this.setState({
durationSource: this.getOption(trialRun[0])
});
}
componentDidMount() {
const { source } = this.props;
this.drawDurationGraph(source);
}
componentWillReceiveProps(nextProps: DurationProps) {
const { whichGraph, source } = nextProps;
if (whichGraph === '3') {
this.drawDurationGraph(source);
}
return this.getOption(trialRun[0]);
}
shouldComponentUpdate(nextProps: DurationProps, nextState: DurationState) {
......@@ -183,18 +117,31 @@ class Duration extends React.Component<DurationProps, DurationState> {
}
render() {
const { durationSource } = this.state;
const { source } = this.props;
const graph = this.drawDurationGraph(source);
const onEvents = { 'dataZoom': this.durationDataZoom };
return (
<div>
<ReactEcharts
option={durationSource}
option={graph}
style={{ width: '95%', height: 412, margin: '0 auto' }}
theme="my_theme"
notMerge={true} // update now
onEvents={onEvents}
/>
</div>
);
}
private durationDataZoom = (e: EventMap) => {
if (e.batch !== undefined) {
this.setState(() => ({
startDuration: (e.batch[0].start !== null ? e.batch[0].start : 0),
endDuration: (e.batch[0].end !== null ? e.batch[0].end : 100)
}));
}
}
}
export default Duration;
import * as React from 'react';
import { Row, Button, Switch } from 'antd';
import { TooltipForIntermediate, TableObj, Intermedia } from '../../static/interface';
import { TooltipForIntermediate, TableObj, Intermedia, EventMap } from '../../static/interface';
import ReactEcharts from 'echarts-for-react';
require('echarts/lib/component/tooltip');
require('echarts/lib/component/title');
......@@ -14,6 +14,8 @@ interface IntermediateState {
isFilter: boolean;
length: number;
clickCounts: number; // user filter intermediate click confirm btn's counts
startMediaY: number;
endMediaY: number;
}
interface IntermediateProps {
......@@ -38,7 +40,9 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
isLoadconfirmBtn: false,
isFilter: false,
length: 100000,
clickCounts: 0
clickCounts: 0,
startMediaY: 0,
endMediaY: 100
};
}
......@@ -48,6 +52,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
length: source.length,
detailSource: source
});
const { startMediaY, endMediaY } = this.state;
const trialIntermediate: Array<Intermedia> = [];
Object.keys(source).map(item => {
const temp = source[item];
......@@ -113,6 +118,16 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
type: 'value',
name: 'Metric'
},
dataZoom: [
{
id: 'dataZoomY',
type: 'inside',
yAxisIndex: [0],
filterMode: 'empty',
start: startMediaY,
end: endMediaY
}
],
series: trialIntermediate
};
this.setState({
......@@ -258,6 +273,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
render() {
const { interSource, isLoadconfirmBtn, isFilter } = this.state;
const IntermediateEvents = { 'dataZoom': this.intermediateDataZoom };
return (
<div>
{/* style in para.scss */}
......@@ -265,7 +281,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
{
isFilter
?
<span style={{marginRight: 15}}>
<span style={{ marginRight: 15 }}>
<span className="filter-x"># Intermediate result</span>
<input
// placeholder="point"
......@@ -306,12 +322,22 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>
option={interSource}
style={{ width: '100%', height: 418, margin: '0 auto' }}
notMerge={true} // update now
onEvents={IntermediateEvents}
/>
<div className="yAxis"># Intermediate result</div>
</Row>
</div>
);
}
private intermediateDataZoom = (e: EventMap) => {
if (e.batch !== undefined) {
this.setState(() => ({
startMediaY: (e.batch[0].start !== null ? e.batch[0].start : 0),
endMediaY: (e.batch[0].end !== null ? e.batch[0].end : 100)
}));
}
}
}
export default Intermediate;
......@@ -275,6 +275,7 @@ class Para extends React.Component<ParaProps, ParaState> {
parallelAxis.push({
dim: i,
name: 'default metric',
scale: true,
nameTextStyle: {
fontWeight: 700
}
......
......@@ -586,17 +586,16 @@ const AccuracyColumnConfig: ColumnProps<TableRecord> = {
dataIndex: 'accuracy',
width: 120,
sorter: (a, b, sortOrder) => {
if (a.accuracy === undefined) {
return sortOrder === 'ascend' ? -1 : 1;
} else if (b.accuracy === undefined) {
if (a.latestAccuracy === undefined) {
return sortOrder === 'ascend' ? 1 : -1;
} else if (b.latestAccuracy === undefined) {
return sortOrder === 'ascend' ? -1 : 1;
} else {
return a.accuracy - b.accuracy;
return a.latestAccuracy - b.latestAccuracy;
}
},
render: (text, record) => (
// TODO: is this needed?
<div>{record.latestAccuracy}</div>
<div>{record.formattedLatestAccuracy}</div>
)
};
......
......@@ -186,5 +186,5 @@ function formatAccuracy(accuracy: number): string {
export {
convertTime, convertDuration, getFinalResult, getFinal, downFile,
intermediateGraphOption, killJob, filterByStatus, filterDuration,
formatAccuracy, formatTimestamp, metricAccuracy,
formatAccuracy, formatTimestamp, metricAccuracy
};
......@@ -24,7 +24,8 @@ interface TableRecord {
status: string;
intermediateCount: number;
accuracy?: number;
latestAccuracy: string; // formatted string
latestAccuracy: number | undefined;
formattedLatestAccuracy: string; // format (LATEST/FINAL)
}
interface SearchSpace {
......@@ -81,6 +82,7 @@ interface Dimobj {
axisLabel?: object;
axisLine?: object;
nameTextStyle?: object;
scale?: boolean;
}
interface ParaObj {
......@@ -179,9 +181,13 @@ interface NNIManagerStatus {
errors: string[];
}
interface EventMap {
[key: string]: () => void;
}
export {
TableObj, TableRecord, Parameters, ExperimentProfile, AccurPoint,
DetailAccurPoint, TooltipForAccuracy, ParaObj, Dimobj, FinalType,
TooltipForIntermediate, SearchSpace, Intermedia, MetricDataRecord, TrialJobInfo,
NNIManagerStatus,
NNIManagerStatus, EventMap
};
......@@ -46,6 +46,22 @@ class Trial implements TableObj {
return this.metricsInitialized && this.finalAcc !== undefined && !isNaN(this.finalAcc);
}
get latestAccuracy(): number | undefined {
if (this.accuracy !== undefined) {
return this.accuracy;
} else if (this.intermediates.length > 0) {
// TODO: support intermeidate result is dict
const temp = this.intermediates[this.intermediates.length - 1];
if (temp !== undefined) {
return JSON.parse(temp.data);
} else {
return undefined;
}
} else {
return undefined;
}
}
/* table obj start */
get tableRecord(): TableRecord {
......@@ -62,7 +78,8 @@ class Trial implements TableObj {
status: this.info.status,
intermediateCount: this.intermediates.length,
accuracy: this.finalAcc,
latestAccuracy: this.formatLatestAccuracy(),
latestAccuracy: this.latestAccuracy,
formattedLatestAccuracy: this.formatLatestAccuracy(),
};
}
......
This source diff could not be displayed because it is too large. You can view the blob instead.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Test code for weight sharing
need NFS setup and mounted as `/mnt/nfs/nni`
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
SimpleTuner for Weight Sharing
"""
......
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
import time
......@@ -26,7 +9,7 @@ from utils import GREEN, RED, CLEAR, setup_experiment
def test_nni_cli():
import nnicli as nc
config_file = 'config_test/examples/mnist.test.yml'
config_file = 'config_test/examples/mnist-tfv1.test.yml'
try:
# Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict
......
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import argparse
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment