"...lm-evaluation-harness.git" did not exist on "692e0f83b5341b543fa288f84289617f793e4e93"
Unverified Commit f9ee589c authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #222 from microsoft/master

merge master
parents 36e6e350 4f3ee9cb
import os
import sys
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import sys
import torch
import nni
from nni.env_vars import trial_env_vars
from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.mutator import Mutator
logger = logging.getLogger(__name__)
NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"
def get_and_apply_next_architecture(model):
"""
Wrapper of ClassicMutator to make it more meaningful,
similar to ```get_next_parameter``` for HPO.
Parameters
----------
model : pytorch model
......@@ -22,12 +31,14 @@ def get_and_apply_next_architecture(model):
"""
ClassicMutator(model)
class ClassicMutator(BaseMutator):
class ClassicMutator(Mutator):
"""
This mutator is to apply the architecture chosen from tuner.
It implements the forward function of LayerChoice and InputChoice,
to only activate the chosen ones
"""
def __init__(self, model):
"""
Generate search space based on ```model```.
......@@ -37,70 +48,131 @@ class ClassicMutator(BaseMutator):
use ```nnictl``` to start an experiment. The other is standalone mode
where users directly run the trial command, this mode chooses the first
one(s) for each LayerChoice and InputChoice.
Parameters
----------
model : pytorch model
model : PyTorch model
user's model with search space (e.g., LayerChoice, InputChoice) embedded in it
"""
super(ClassicMutator, self).__init__(model)
self.chosen_arch = {}
self.search_space = self._generate_search_space()
if 'NNI_GEN_SEARCH_SPACE' in os.environ:
self._chosen_arch = {}
self._search_space = self._generate_search_space()
if NNI_GEN_SEARCH_SPACE in os.environ:
# dry run for only generating search space
self._dump_search_space(self.search_space, os.environ.get('NNI_GEN_SEARCH_SPACE'))
self._dump_search_space(os.environ[NNI_GEN_SEARCH_SPACE])
sys.exit(0)
# get chosen arch from tuner
self.chosen_arch = nni.get_next_parameter()
if not self.chosen_arch and trial_env_vars.NNI_PLATFORM is None:
logger.warning('This is in standalone mode, the chosen are the first one(s)')
self.chosen_arch = self._standalone_generate_chosen()
self._validate_chosen_arch()
def _validate_chosen_arch(self):
pass
if trial_env_vars.NNI_PLATFORM is None:
logger.warning("This is in standalone mode, the chosen are the first one(s).")
self._chosen_arch = self._standalone_generate_chosen()
else:
# get chosen arch from tuner
self._chosen_arch = nni.get_next_parameter()
self.reset()
def _standalone_generate_chosen(self):
def _sample_layer_choice(self, mutable, idx, value, search_space_item):
"""
Generate the chosen architecture for standalone mode,
i.e., choose the first one(s) for LayerChoice and InputChoice
Convert layer choice to tensor representation.
{ key_name: {'_value': "conv1",
'_idx': 0} }
Parameters
----------
mutable : Mutable
idx : int
Number `idx` of list will be selected.
value : str
The verbose representation of the selected value.
search_space_item : list
The list for corresponding search space.
"""
# doesn't support multihot for layer choice yet
onehot_list = [False] * mutable.length
assert 0 <= idx < mutable.length and search_space_item[idx] == value, \
"Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
onehot_list[idx] = True
return torch.tensor(onehot_list, dtype=torch.bool) # pylint: disable=not-callable
def _sample_input_choice(self, mutable, idx, value, search_space_item):
"""
Convert input choice to tensor representation.
{ key_name: {'_value': ["in1"],
'_idx': [0]} }
Parameters
----------
mutable : Mutable
idx : int
Number `idx` of list will be selected.
value : str
The verbose representation of the selected value.
search_space_item : list
The list for corresponding search space.
"""
multihot_list = [False] * mutable.n_candidates
for i, v in zip(idx, value):
assert 0 <= i < mutable.n_candidates and search_space_item[i] == v, \
"Index '{}' in search space '{}' is not '{}'".format(i, search_space_item, v)
assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
multihot_list[i] = True
return torch.tensor(multihot_list, dtype=torch.bool) # pylint: disable=not-callable
def sample_search(self):
return self.sample_final()
def sample_final(self):
assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
"Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
self._chosen_arch.keys())
result = dict()
for mutable in self.mutables:
assert mutable.key in self._chosen_arch, "Expected '{}' in chosen arch, but not found.".format(mutable.key)
data = self._chosen_arch[mutable.key]
assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
"'{}' is not a valid choice.".format(data)
value = data["_value"]
idx = data["_idx"]
search_space_item = self._search_space[mutable.key]["_value"]
if isinstance(mutable, LayerChoice):
result[mutable.key] = self._sample_layer_choice(mutable, idx, value, search_space_item)
elif isinstance(mutable, InputChoice):
result[mutable.key] = self._sample_input_choice(mutable, idx, value, search_space_item)
else:
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return result
def _standalone_generate_chosen(self):
"""
Generate the chosen architecture for standalone mode,
i.e., choose the first one(s) for LayerChoice and InputChoice.
::
{ key_name: {"_value": "conv1",
"_idx": 0} }
{ key_name: {"_value": ["in1"],
"_idx": [0]} }
Returns
-------
dict
the chosen architecture
"""
chosen_arch = {}
for key, val in self.search_space.items():
if val['_type'] == 'layer_choice':
choices = val['_value']
chosen_arch[key] = {'_value': choices[0], '_idx': 0}
elif val['_type'] == 'input_choice':
choices = val['_value']['candidates']
n_chosen = val['_value']['n_chosen']
chosen_arch[key] = {'_value': choices[:n_chosen], '_idx': list(range(n_chosen))}
for key, val in self._search_space.items():
if val["_type"] == LAYER_CHOICE:
choices = val["_value"]
chosen_arch[key] = {"_value": choices[0], "_idx": 0}
elif val["_type"] == INPUT_CHOICE:
choices = val["_value"]["candidates"]
n_chosen = val["_value"]["n_chosen"]
chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
else:
raise ValueError('Unknown key %s and value %s' % (key, val))
raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
return chosen_arch
def _generate_search_space(self):
"""
Generate search space from mutables.
Here is the search space format:
{ key_name: {'_type': 'layer_choice',
'_value': ["conv1", "conv2"]} }
{ key_name: {'_type': 'input_choice',
'_value': {'candidates': ["in1", "in2"],
'n_chosen': 1}} }
::
{ key_name: {"_type": "layer_choice",
"_value": ["conv1", "conv2"]} }
{ key_name: {"_type": "input_choice",
"_value": {"candidates": ["in1", "in2"],
"n_chosen": 1}} }
Returns
-------
dict
......@@ -112,81 +184,16 @@ class ClassicMutator(BaseMutator):
if isinstance(mutable, LayerChoice):
key = mutable.key
val = [repr(choice) for choice in mutable.choices]
search_space[key] = {"_type": "layer_choice", "_value": val}
search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
elif isinstance(mutable, InputChoice):
key = mutable.key
search_space[key] = {"_type": "input_choice",
search_space[key] = {"_type": INPUT_CHOICE,
"_value": {"candidates": mutable.choose_from,
"n_chosen": mutable.n_chosen}}
else:
raise TypeError('Unsupported mutable type: %s.' % type(mutable))
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return search_space
def _dump_search_space(self, search_space, file_path):
with open(file_path, 'w') as ss_file:
json.dump(search_space, ss_file)
def _tensor_reduction(self, reduction_type, tensor_list):
if tensor_list == "none":
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == "sum":
return sum(tensor_list)
if reduction_type == "mean":
return sum(tensor_list) / len(tensor_list)
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
def on_forward_layer_choice(self, mutable, *inputs):
"""
Implement the forward of LayerChoice
Parameters
----------
mutable: LayerChoice
inputs: list of torch.Tensor
Returns
-------
tuple
return of the chosen op, the index of the chosen op
"""
assert mutable.key in self.chosen_arch
val = self.chosen_arch[mutable.key]
assert isinstance(val, dict)
idx = val['_idx']
assert self.search_space[mutable.key]['_value'][idx] == val['_value']
return mutable.choices[idx](*inputs), idx
def on_forward_input_choice(self, mutable, tensor_list):
"""
Implement the forward of InputChoice
Parameters
----------
mutable: InputChoice
tensor_list: list of torch.Tensor
tags: list of string
Returns
-------
tuple of torch.Tensor and list
reduced tensor, mask list
"""
assert mutable.key in self.chosen_arch
val = self.chosen_arch[mutable.key]
assert isinstance(val, dict)
mask = [0 for _ in range(mutable.n_candidates)]
out = []
for i, idx in enumerate(val['_idx']):
# check whether idx matches the chosen candidate name
assert self.search_space[mutable.key]['_value']['candidates'][idx] == val['_value'][i]
out.append(tensor_list[idx])
mask[idx] = 1
return self._tensor_reduction(mutable.reduction, out), mask
def _dump_search_space(self, file_path):
with open(file_path, "w") as ss_file:
json.dump(self._search_space, ss_file, sort_keys=True, indent=2)
......@@ -41,7 +41,8 @@ class Mutator(BaseMutator):
def reset(self):
"""
Reset the mutator by call the `sample_search` to resample (for search).
Reset the mutator by call the `sample_search` to resample (for search). Stores the result in a local
variable so that `on_forward_layer_choice` and `on_forward_input_choice` can use the decision directly.
Returns
-------
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import OrderedDict
_counter = 0
_logger = logging.getLogger(__name__)
def global_mutable_counting():
global _counter
......@@ -23,6 +26,12 @@ class AverageMeterGroup:
self.meters[k] = AverageMeter(k, ":4f")
self.meters[k].update(v)
def __getattr__(self, item):
return self.meters[item]
def __getitem__(self, item):
return self.meters[item]
def __str__(self):
return " ".join(str(v) for _, v in self.meters.items())
......@@ -52,6 +61,8 @@ class AverageMeter:
self.count = 0
def update(self, val, n=1):
if not isinstance(val, float) and not isinstance(val, int):
_logger.warning("Values passed to AverageMeter must be number, not %s.", type(val))
self.val = val
self.sum += val * n
self.count += n
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..env_vars import trial_env_vars
from ..env_vars import trial_env_vars, dispatcher_env_vars
assert dispatcher_env_vars.SDK_PROCESS != 'dispatcher'
if trial_env_vars.NNI_PLATFORM is None:
from .standalone import *
elif trial_env_vars.NNI_PLATFORM == 'unittest':
from .test import *
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'):
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn'):
from .local import *
else:
raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
......@@ -18,26 +18,31 @@ from smac.utils.io.cmd_reader import CMDReader
from ConfigSpaceNNI import Configuration
import nni
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
from .convert_ss_to_scenario import generate_scenario
logger = logging.getLogger('smac_AutoML')
class SMACTuner(Tuner):
"""
This is a wrapper of [SMAC](https://github.com/automl/SMAC3) following NNI tuner interface.
It only supports ``SMAC`` mode, and does not support the multiple instances of SMAC3 (i.e.,
the same configuration is run multiple times).
"""
def __init__(self, optimize_mode="maximize"):
def __init__(self, optimize_mode="maximize", config_dedup=False):
"""
Parameters
----------
optimize_mode : str
Optimize mode, 'maximize' or 'minimize', by default 'maximize'
config_dedup : bool
If True, the tuner will not generate a configuration that has been already generated.
If False, a configuration may be generated twice, but it is rare for relatively large search space.
"""
self.logger = logging.getLogger(
self.__module__ + "." + self.__class__.__name__)
self.logger = logger
self.optimize_mode = OptimizeMode(optimize_mode)
self.total_data = {}
self.optimizer = None
......@@ -47,6 +52,7 @@ class SMACTuner(Tuner):
self.loguniform_key = set()
self.categorical_dict = {}
self.cs = None
self.dedup = config_dedup
def _main_cli(self):
"""
......@@ -127,7 +133,7 @@ class SMACTuner(Tuner):
search_space : dict
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
"""
self.logger.info('update search space in SMAC.')
if not self.update_ss_done:
self.categorical_dict = generate_scenario(search_space)
if self.categorical_dict is None:
......@@ -225,9 +231,19 @@ class SMACTuner(Tuner):
return self.param_postprocess(init_challenger.get_dictionary())
else:
challengers = self.smbo_solver.nni_smac_request_challengers()
challengers_empty = True
for challenger in challengers:
challengers_empty = False
if self.dedup:
match = [v for k, v in self.total_data.items() \
if v.get_dictionary() == challenger.get_dictionary()]
if match:
continue
self.total_data[parameter_id] = challenger
return self.param_postprocess(challenger.get_dictionary())
assert challengers_empty is False, 'The case that challengers is empty is not handled.'
self.logger.info('In generate_parameters: No more new parameters.')
raise nni.NoMoreTrialError('No more new parameters.')
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""
......@@ -261,9 +277,16 @@ class SMACTuner(Tuner):
for challenger in challengers:
if cnt >= len(parameter_id_list):
break
if self.dedup:
match = [v for k, v in self.total_data.items() \
if v.get_dictionary() == challenger.get_dictionary()]
if match:
continue
self.total_data[parameter_id_list[cnt]] = challenger
params.append(self.param_postprocess(challenger.get_dictionary()))
cnt += 1
if self.dedup and not params:
self.logger.info('In generate_multiple_parameters: No more new parameters.')
return params
def import_data(self, data):
......
......@@ -23,7 +23,8 @@
"@typescript-eslint/consistent-type-assertions": 0,
"@typescript-eslint/no-inferrable-types": 0,
"no-inner-declarations": 0,
"@typescript-eslint/no-var-requires": 0
"@typescript-eslint/no-var-requires": 0,
"react/display-name": 0
},
"ignorePatterns": [
"node_modules/",
......
......@@ -27,7 +27,7 @@ class App extends React.Component<{}, AppState> {
};
}
async componentDidMount() {
async componentDidMount(): Promise<void> {
await Promise.all([ EXPERIMENT.init(), TRIALS.init() ]);
this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
this.setState(state => ({ trialsUpdateBroadcast: state.trialsUpdateBroadcast + 1 }));
......@@ -35,7 +35,7 @@ class App extends React.Component<{}, AppState> {
this.setState({ metricGraphMode: (EXPERIMENT.optimizeMode === 'minimize' ? 'min' : 'max') });
}
changeInterval = (interval: number) => {
changeInterval = (interval: number): void => {
this.setState({ interval });
if (this.timerId === null && interval !== 0) {
window.setTimeout(this.refresh);
......@@ -45,15 +45,15 @@ class App extends React.Component<{}, AppState> {
}
// TODO: use local storage
changeColumn = (columnList: Array<string>) => {
changeColumn = (columnList: Array<string>): void => {
this.setState({ columnList: columnList });
}
changeMetricGraphMode = (val: 'max' | 'min') => {
changeMetricGraphMode = (val: 'max' | 'min'): void => {
this.setState({ metricGraphMode: val });
}
render() {
render(): React.ReactNode{
const { interval, columnList, experimentUpdateBroadcast, trialsUpdateBroadcast, metricGraphMode } = this.state;
if (experimentUpdateBroadcast === 0 || trialsUpdateBroadcast === 0) {
return null; // TODO: render a loading page
......@@ -86,7 +86,7 @@ class App extends React.Component<{}, AppState> {
);
}
private refresh = async () => {
private refresh = async (): Promise<void> => {
const [ experimentUpdated, trialsUpdated ] = await Promise.all([ EXPERIMENT.update(), TRIALS.update() ]);
if (experimentUpdated) {
this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
......@@ -107,7 +107,7 @@ class App extends React.Component<{}, AppState> {
}
}
private async lastRefresh() {
private async lastRefresh(): Promise<void> {
await EXPERIMENT.update();
await TRIALS.update(true);
this.setState(state => ({ experimentUpdateBroadcast: state.experimentUpdateBroadcast + 1 }));
......
......@@ -20,7 +20,7 @@ class Compare extends React.Component<CompareProps, {}> {
super(props);
}
intermediate = () => {
intermediate = (): any => {
const { compareRows } = this.props;
const trialIntermediate: Array<Intermedia> = [];
const idsList: Array<string> = [];
......@@ -40,7 +40,7 @@ class Compare extends React.Component<CompareProps, {}> {
const legend: Array<string> = [];
// max length
const length = trialIntermediate[0] !== undefined ? trialIntermediate[0].data.length : 0;
const xAxis: Array<number> = [];
const xAxis: number[] = [];
Object.keys(trialIntermediate).map(item => {
const temp = trialIntermediate[item];
legend.push(temp.name);
......@@ -52,14 +52,14 @@ class Compare extends React.Component<CompareProps, {}> {
tooltip: {
trigger: 'item',
enterable: true,
position: function (point: Array<number>, data: TooltipForIntermediate) {
position: function (point: number[], data: TooltipForIntermediate): number[] {
if (data.dataIndex < length / 2) {
return [point[0], 80];
} else {
return [point[0] - 300, 80];
}
},
formatter: function (data: TooltipForIntermediate) {
formatter: function (data: TooltipForIntermediate): any {
const trialId = data.seriesName;
let obj = {};
const temp = trialIntermediate.find(key => key.name === trialId);
......@@ -106,10 +106,10 @@ class Compare extends React.Component<CompareProps, {}> {
}
// render table column ---
initColumn = () => {
initColumn = (): React.ReactNode => {
const idList: Array<string> = [];
const sequenceIdList: Array<number> = [];
const durationList: Array<number> = [];
const sequenceIdList: number[] = [];
const durationList: number[] = [];
const compareRows = this.props.compareRows.map(tableRecord => TRIALS.getTrial(tableRecord.id));
......@@ -195,15 +195,15 @@ class Compare extends React.Component<CompareProps, {}> {
);
}
componentDidMount() {
componentDidMount(): void {
this._isCompareMount = true;
}
componentWillUnmount() {
componentWillUnmount(): void {
this._isCompareMount = false;
}
render() {
render(): React.ReactNode{
const { visible, cancelFunc } = this.props;
return (
......
......@@ -39,7 +39,7 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
}
// [submit click] user add a new trial [submit a trial]
addNewTrial = () => {
addNewTrial = (): void => {
const { searchSpace, copyTrialParameter } = this.state;
// get user edited hyperParameter, ps: will change data type if you modify the input val
const customized = this.props.form.getFieldsValue();
......@@ -76,19 +76,19 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
}
warningConfirm = () => {
warningConfirm = (): void => {
this.setState(() => ({ isShowWarning: false }));
const { customParameters } = this.state;
this.submitCustomize(customParameters);
}
warningCancel = () => {
warningCancel = (): void => {
this.setState(() => ({ isShowWarning: false }));
}
submitCustomize = (customized: Object) => {
submitCustomize = (customized: Record<string, any>): void => {
// delete `tag` key
for (let i in customized) {
for (const i in customized) {
if (i === 'tag') {
delete customized[i];
}
......@@ -106,24 +106,24 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
this.setState(() => ({ isShowSubmitFailed: true }));
}
})
.catch(error => {
.catch(() => {
this.setState(() => ({ isShowSubmitFailed: true }));
});
}
closeSucceedHint = () => {
closeSucceedHint = (): void => {
// also close customized trial modal
this.setState(() => ({ isShowSubmitSucceed: false }));
this.props.closeCustomizeModal();
}
closeFailedHint = () => {
closeFailedHint = (): void => {
// also close customized trial modal
this.setState(() => ({ isShowSubmitFailed: false }));
this.props.closeCustomizeModal();
}
componentDidMount() {
componentDidMount(): void {
const { copyTrialId } = this.props;
if (copyTrialId !== undefined && TRIALS.getTrial(copyTrialId) !== undefined) {
const originCopyTrialPara = TRIALS.getTrial(copyTrialId).description.parameters;
......@@ -131,7 +131,7 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
}
}
componentWillReceiveProps(nextProps: CustomizeProps) {
componentWillReceiveProps(nextProps: CustomizeProps): void {
const { copyTrialId } = nextProps;
if (copyTrialId !== undefined && TRIALS.getTrial(copyTrialId) !== undefined) {
const originCopyTrialPara = TRIALS.getTrial(copyTrialId).description.parameters;
......@@ -139,7 +139,7 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
}
}
render() {
render(): React.ReactNode {
const { closeCustomizeModal, visible } = this.props;
const { isShowSubmitSucceed, isShowSubmitFailed, isShowWarning, customID, copyTrialParameter } = this.state;
const {
......
......@@ -29,7 +29,7 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> {
};
}
getExperimentContent = () => {
getExperimentContent = (): void => {
axios
.all([
axios.get(`${MANAGER_IP}/experiment`),
......@@ -41,7 +41,7 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> {
if (res.data.params.searchSpace) {
res.data.params.searchSpace = JSON.parse(res.data.params.searchSpace);
}
let trialMessagesArr = res1.data;
const trialMessagesArr = res1.data;
const interResultList = res2.data;
Object.keys(trialMessagesArr).map(item => {
// not deal with trial's hyperParameters
......@@ -66,34 +66,34 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> {
}));
}
downExperimentParameters = () => {
downExperimentParameters = (): void => {
const { experiment } = this.state;
downFile(experiment, 'experiment.json');
}
onWindowResize = () => {
onWindowResize = (): void => {
this.setState(() => ({expDrawerHeight: window.innerHeight - 48}));
}
componentDidMount() {
componentDidMount(): void {
this._isCompareMount = true;
this.getExperimentContent();
window.addEventListener('resize', this.onWindowResize);
}
componentWillReceiveProps(nextProps: ExpDrawerProps) {
componentWillReceiveProps(nextProps: ExpDrawerProps): void {
const { isVisble } = nextProps;
if (isVisble === true) {
this.getExperimentContent();
}
}
componentWillUnmount() {
componentWillUnmount(): void {
this._isCompareMount = false;
window.removeEventListener('resize', this.onWindowResize);
}
render() {
render(): React.ReactNode {
const { isVisble, closeExpDrawer } = this.props;
const { experiment, expDrawerHeight } = this.state;
return (
......
......@@ -33,19 +33,19 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
};
}
downloadNNImanager = () => {
downloadNNImanager = (): void => {
if (this.state.nniManagerLogStr !== null) {
downFile(this.state.nniManagerLogStr, 'nnimanager.log');
}
}
downloadDispatcher = () => {
downloadDispatcher = (): void => {
if (this.state.dispatcherLogStr !== null) {
downFile(this.state.dispatcherLogStr, 'dispatcher.log');
}
}
dispatcherHTML = () => {
dispatcherHTML = (): any => {
return (
<div>
<span>Dispatcher Log</span>
......@@ -56,7 +56,7 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
);
}
nnimanagerHTML = () => {
nnimanagerHTML = (): any => {
return (
<div>
<span>NNImanager Log</span>
......@@ -65,21 +65,21 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
);
}
setLogDrawerHeight = () => {
setLogDrawerHeight = (): void => {
this.setState(() => ({ logDrawerHeight: window.innerHeight - 48 }));
}
async componentDidMount() {
async componentDidMount(): Promise<void> {
this.refresh();
window.addEventListener('resize', this.setLogDrawerHeight);
}
componentWillUnmount() {
componentWillUnmount(): void {
window.clearTimeout(this.timerId);
window.removeEventListener('resize', this.setLogDrawerHeight);
}
render() {
render(): React.ReactNode {
const { closeDrawer, activeTab } = this.props;
const { nniManagerLogStr, dispatcherLogStr, isLoading, logDrawerHeight } = this.state;
return (
......@@ -164,7 +164,7 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
);
}
private refresh = () => {
private refresh = (): void => {
window.clearTimeout(this.timerId);
const dispatcherPromise = axios.get(`${DOWNLOAD_IP}/dispatcher.log`);
const nniManagerPromise = axios.get(`${DOWNLOAD_IP}/nnimanager.log`);
......@@ -184,7 +184,7 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
});
}
private manualRefresh = () => {
private manualRefresh = (): void => {
this.setState({ isLoading: true });
this.refresh();
}
......
......@@ -35,30 +35,31 @@ class Overview extends React.Component<OverviewProps, OverviewState> {
};
}
clickMaxTop = (event: React.SyntheticEvent<EventTarget>) => {
clickMaxTop = (event: React.SyntheticEvent<EventTarget>): void => {
event.stopPropagation();
// #999 panel active bgcolor; #b3b3b3 as usual
const { changeMetricGraphMode } = this.props;
changeMetricGraphMode('max');
}
clickMinTop = (event: React.SyntheticEvent<EventTarget>) => {
clickMinTop = (event: React.SyntheticEvent<EventTarget>): void => {
event.stopPropagation();
const { changeMetricGraphMode } = this.props;
changeMetricGraphMode('min');
}
changeConcurrency = (val: number) => {
changeConcurrency = (val: number): void => {
this.setState({ trialConcurrency: val });
}
render() {
render(): React.ReactNode {
const { trialConcurrency } = this.state;
const { experimentUpdateBroadcast, metricGraphMode } = this.props;
const searchSpace = this.convertSearchSpace();
const bestTrials = this.findBestTrials();
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const bestAccuracy = bestTrials.length > 0 ? bestTrials[0].accuracy! : NaN;
const accuracyGraphData = this.generateAccuracyGraph(bestTrials);
const noDataMessage = bestTrials.length > 0 ? '' : 'No data';
......@@ -147,7 +148,7 @@ class Overview extends React.Component<OverviewProps, OverviewState> {
const searchSpace = Object.assign({}, EXPERIMENT.searchSpace);
Object.keys(searchSpace).map(item => {
const key = searchSpace[item]._type;
let value = searchSpace[item]._value;
const value = searchSpace[item]._value;
switch (key) {
case 'quniform':
case 'qnormal':
......@@ -161,7 +162,7 @@ class Overview extends React.Component<OverviewProps, OverviewState> {
}
private findBestTrials(): Trial[] {
let bestTrials = TRIALS.sort();
const bestTrials = TRIALS.sort();
if (this.props.metricGraphMode === 'max') {
bestTrials.reverse().splice(10);
} else {
......
......@@ -50,7 +50,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
};
}
getNNIversion = () => {
getNNIversion = (): void => {
axios(`${MANAGER_IP}/version`, {
method: 'GET'
})
......@@ -61,7 +61,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
});
}
handleMenuClick = (e: EventPer) => {
handleMenuClick = (e: EventPer): void => {
this.setState({ menuVisible: false });
switch (e.key) {
// to see & download experiment parameters
......@@ -87,11 +87,11 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
}
handleVisibleChange = (flag: boolean) => {
handleVisibleChange = (flag: boolean): void => {
this.setState({ menuVisible: flag });
}
getInterval = (value: string) => {
getInterval = (value: string): void => {
if (value === 'close') {
this.props.changeInterval(0);
} else {
......@@ -99,7 +99,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
}
menu = () => {
menu = (): any => {
return (
<Menu onClick={this.handleMenuClick}>
<Menu.Item key="1">Experiment Parameters</Menu.Item>
......@@ -110,14 +110,14 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
}
// nav bar
navigationBar = () => {
navigationBar = (): any => {
const { version } = this.state;
const feedBackLink = `https://github.com/Microsoft/nni/issues/new?labels=${version}`;
return (
<Menu onClick={this.handleMenuClick} className="menu-list" style={{ width: 216 }}>
{/* <Menu onClick={this.handleMenuClick} className="menu-list" style={{width: window.innerWidth}}> */}
<Menu.Item key="feedback">
<a href={feedBackLink} target="_blank">Feedback</a>
<a href={feedBackLink} rel="noopener noreferrer" target="_blank">Feedback</a>
</Menu.Item>
<Menu.Item key="version">Version: {version}</Menu.Item>
<SubMenu
......@@ -137,7 +137,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
);
}
mobileTabs = () => {
mobileTabs = (): any => {
return (
// <Menu className="menuModal" style={{width: 880, position: 'fixed', left: 0, top: 56}}>
<Menu className="menuModal" style={{ padding: '0 10px' }}>
......@@ -147,7 +147,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
);
}
refreshInterval = () => {
refreshInterval = (): any => {
const {
form: { getFieldDecorator },
// form: { getFieldDecorator, getFieldValue },
......@@ -171,7 +171,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
);
}
select = () => {
select = (): any => {
const { isdisabledFresh } = this.state;
return (
......@@ -189,7 +189,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
);
}
fresh = (event: React.SyntheticEvent<EventTarget>) => {
fresh = (event: React.SyntheticEvent<EventTarget>): void => {
event.preventDefault();
event.stopPropagation();
this.setState({ isdisabledFresh: true }, () => {
......@@ -197,7 +197,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
});
}
desktopHTML = () => {
desktopHTML = (): any => {
const { version, menuVisible } = this.state;
const feed = `https://github.com/Microsoft/nni/issues/new?labels=${version}`;
return (
......@@ -213,7 +213,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
className="fresh"
type="ghost"
>
<a target="_blank" href="https://nni.readthedocs.io/en/latest/Tutorial/WebUI.html">
<a target="_blank" rel="noopener noreferrer" href="https://nni.readthedocs.io/en/latest/Tutorial/WebUI.html">
<img
src={require('../static/img/icon/ques.png')}
alt="question"
......@@ -246,7 +246,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
</Dropdown>
</span>
<span className="feedback">
<a href={feed} target="_blank">
<a href={feed} target="_blank" rel="noopener noreferrer">
<img
src={require('../static/img/icon/issue.png')}
alt="NNI github issue"
......@@ -260,7 +260,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
);
}
tabeltHTML = () => {
tabeltHTML = (): any => {
return (
<Row className="nav">
<Col className="tabelt-left" span={14}>
......@@ -280,7 +280,7 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
);
}
mobileHTML = () => {
mobileHTML = (): any => {
const { isdisabledFresh } = this.state;
return (
<Row className="nav">
......@@ -319,20 +319,20 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
);
}
// close log drawer (nnimanager.dispatcher)
closeLogDrawer = () => {
closeLogDrawer = (): void => {
this.setState({ isvisibleLogDrawer: false, activeKey: '' });
}
// close download experiment parameters drawer
closeExpDrawer = () => {
closeExpDrawer = (): void => {
this.setState({ isvisibleExperimentDrawer: false });
}
componentDidMount() {
componentDidMount(): void {
this.getNNIversion();
}
render() {
render(): React.ReactNode {
const mobile = (<MediaQuery maxWidth={884}>{this.mobileHTML()}</MediaQuery>);
const tablet = (<MediaQuery minWidth={885} maxWidth={1281}>{this.tabeltHTML()}</MediaQuery>);
const desktop = (<MediaQuery minWidth={1282}>{this.desktopHTML()}</MediaQuery>);
......
......@@ -60,13 +60,15 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
tablePageSize: 20,
whichGraph: '1',
searchType: 'id',
searchFilter: trial => true,
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type, @typescript-eslint/no-unused-vars
searchFilter: trial => true
};
}
// search a trial by trial No. & trial id
searchTrial = (event: React.ChangeEvent<HTMLInputElement>) => {
searchTrial = (event: React.ChangeEvent<HTMLInputElement>): void => {
const targetValue = event.target.value;
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type, @typescript-eslint/no-unused-vars
let filter = (trial: Trial) => true;
if (!targetValue.trim()) {
this.setState({ searchFilter: filter });
......@@ -74,17 +76,17 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
}
switch (this.state.searchType) {
case 'id':
filter = trial => trial.info.id.toUpperCase().includes(targetValue.toUpperCase());
filter = (trial): boolean => trial.info.id.toUpperCase().includes(targetValue.toUpperCase());
break;
case 'Trial No.':
filter = trial => trial.info.sequenceId.toString() === targetValue;
filter = (trial): boolean => trial.info.sequenceId.toString() === targetValue;
break;
case 'status':
filter = trial => trial.info.status.toUpperCase().includes(targetValue.toUpperCase());
filter = (trial): boolean => trial.info.status.toUpperCase().includes(targetValue.toUpperCase());
break;
case 'parameters':
// TODO: support filters like `x: 2` (instead of `"x": 2`)
filter = trial => JSON.stringify(trial.info.hyperParameters, null, 4).includes(targetValue);
filter = (trial): boolean => JSON.stringify(trial.info.hyperParameters, null, 4).includes(targetValue);
break;
default:
alert(`Unexpected search filter ${this.state.searchType}`);
......@@ -92,15 +94,15 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
this.setState({ searchFilter: filter });
}
handleTablePageSizeSelect = (value: string) => {
handleTablePageSizeSelect = (value: string): void => {
this.setState({ tablePageSize: value === 'all' ? -1 : parseInt(value, 10) });
}
handleWhichTabs = (activeKey: string) => {
handleWhichTabs = (activeKey: string): void => {
this.setState({ whichGraph: activeKey });
}
updateSearchFilterType = (value: string) => {
updateSearchFilterType = (value: string): void => {
// clear input value and re-render table
if (this.searchInput !== null) {
this.searchInput.value = '';
......@@ -108,7 +110,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
this.setState({ searchType: value });
}
render() {
render(): React.ReactNode {
const { tablePageSize, whichGraph } = this.state;
const { columnList, changeColumn } = this.props;
const source = TRIALS.filter(this.state.searchFilter);
......@@ -163,14 +165,14 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<Col span={14} className="right">
<Button
className="common"
onClick={() => { if (this.tableList) { this.tableList.addColumn(); }}}
onClick={(): void => { if (this.tableList) { this.tableList.addColumn(); }}}
>
Add column
</Button>
<Button
className="mediateBtn common"
// use child-component tableList's function, the function is in child-component.
onClick={() => { if (this.tableList) { this.tableList.compareBtn(); }}}
onClick={(): void => { if (this.tableList) { this.tableList.compareBtn(); }}}
>
Compare
</Button>
......@@ -186,6 +188,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
placeholder={`Search by ${this.state.searchType}`}
onChange={this.searchTrial}
style={{ width: 230 }}
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type
ref={text => (this.searchInput) = text}
/>
</Col>
......@@ -196,7 +199,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
columnList={columnList}
changeColumn={changeColumn}
trialsUpdateBroadcast={this.props.trialsUpdateBroadcast}
ref={(tabList) => this.tableList = tabList}
ref={(tabList) => this.tableList = tabList} // eslint-disable-line @typescript-eslint/explicit-function-return-type
/>
</div>
);
......
......@@ -21,7 +21,7 @@ class Accuracy extends React.Component<AccuracyProps, {}> {
}
render() {
render(): React.ReactNode {
const { accNodata, accuracyData, height } = this.props;
return (
<div>
......
......@@ -12,7 +12,7 @@ class BasicInfo extends React.Component<BasicInfoProps, {}> {
super(props);
}
render() {
render(): React.ReactNode {
return (
<Row className="main">
<Col span={8} className="padItem basic">
......
......@@ -18,22 +18,22 @@ class ConcurrencyInput extends React.Component<ConcurrencyInputProps, Concurrenc
this.state = { editting: false };
}
save = () => {
save = (): void => {
if (this.input.current !== null) {
this.props.updateValue(this.input.current.value);
this.setState({ editting: false });
}
}
cancel = () => {
cancel = (): void => {
this.setState({ editting: false });
}
edit = () => {
edit = (): void => {
this.setState({ editting: true });
}
render() {
render(): React.ReactNode {
if (this.state.editting) {
return (
<Row className="inputBox">
......
......@@ -29,7 +29,7 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
};
}
editTrialConcurrency = async (userInput: string) => {
editTrialConcurrency = async (userInput: string): Promise<void> => {
if (!userInput.match(/^[1-9]\d*$/)) {
message.error('Please enter a positive integer!', 2);
return;
......@@ -46,6 +46,7 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
// rest api, modify trial concurrency value
try {
const res = await axios.put(`${MANAGER_IP}/experiment`, newProfile, {
// eslint-disable-next-line @typescript-eslint/camelcase
params: { update_type: 'TRIAL_CONCURRENCY' }
});
if (res.status === 200) {
......@@ -66,20 +67,22 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
}
}
isShowDrawer = () => {
isShowDrawer = (): void => {
this.setState({ isShowLogDrawer: true });
}
closeDrawer = () => {
closeDrawer = (): void => {
this.setState({ isShowLogDrawer: false });
}
render() {
render(): React.ReactNode {
const { bestAccuracy } = this.props;
const { isShowLogDrawer } = this.state;
const count = TRIALS.countStatus();
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const stoppedCount = count.get('USER_CANCELED')! + count.get('SYS_CANCELED')! + count.get('EARLY_STOPPED')!;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const bar2 = count.get('RUNNING')! + count.get('SUCCEEDED')! + count.get('FAILED')! + stoppedCount;
const bar2Percent = (bar2 / EXPERIMENT.profile.params.maxTrialNum) * 100;
......@@ -98,6 +101,7 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
</div>
);
}
return (
<Row className="progress" id="barBack">
<Row className="basic lineBasic">
......
......@@ -16,7 +16,7 @@ class ProgressBar extends React.Component<ProItemProps, {}> {
}
render() {
render(): React.ReactNode {
const { who, percent, description, maxString, bgclass } = this.props;
return (
......@@ -31,7 +31,7 @@ class ProgressBar extends React.Component<ProItemProps, {}> {
percent={percent}
strokeWidth={30}
// strokeLinecap={'square'}
format={() => description}
format={(): string => description}
/>
</div>
<Row className="description">
......
......@@ -13,7 +13,7 @@ class SearchSpace extends React.Component<SearchspaceProps, {}> {
}
render() {
render(): React.ReactNode {
const { searchSpace } = this.props;
return (
<div className="searchSpace">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment