"test/ut/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "2566badb06095b9e3ea16eb6f00fd58da65a95fd"
Unverified commit 755ac5f0, authored by SparkSnail, committed by GitHub

Merge pull request #199 from microsoft/master

merge master
parents 521ca508 f23f8a06
@@ -38,6 +38,7 @@ build:
	cd $(CWD)../../src/webui && $(NNI_YARN) && $(NNI_YARN) build
	rm -rf $(CWD)nni
	cp -r $(CWD)../../src/nni_manager/dist $(CWD)nni
+	cp -r $(CWD)../../src/nni_manager/config $(CWD)nni
	cp -r $(CWD)../../src/webui/build $(CWD)nni/static
	cp $(CWD)../../src/nni_manager/package.json $(CWD)nni
	sed -ie 's/$(NNI_VERSION_TEMPLATE)/$(NNI_VERSION_VALUE)/' $(CWD)nni/package.json
# General Programming Interface for Neural Architecture Search (experimental feature)

_*This is an experimental feature. Currently we have only implemented the general NAS programming interface; weight sharing will be supported in the following releases.*_

Automatic neural architecture search is taking an increasingly important role in finding better models. Recent research has proved the feasibility of automatic NAS and has found models that beat manually designed and tuned ones. Representative works include [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5], and new innovations keep emerging. However, it takes great effort to implement those algorithms, and it is hard to reuse the code base of one algorithm to implement another.

To facilitate NAS innovations (e.g., designing and implementing new NAS models, comparing different NAS models side by side), an easy-to-use and flexible programming interface is crucial.
<a name="ProgInterface"></a>
## Programming interface

A new programming interface for designing and searching a model is often demanded in two scenarios. 1) When designing a neural network, the designer may have multiple choices for a layer, sub-model, or connection, and is not sure which one, or which combination, performs best. It would be appealing to have an easy way to express the candidate layers/sub-models one wants to try. 2) Researchers working on automatic NAS want a unified way to express the search space of neural architectures, so that unchanged trial code can be adapted to different search algorithms.
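As a rough sketch of how such candidates can be expressed with NNI annotation (the exact annotation grammar is defined by NNI and may differ from this illustration; `conv`, `pool`, `identity`, and `out1` to `out3` stand for user-defined functions and tensors in the trial code):

```python
# Illustrative sketch only; see the NNI annotation docs for the exact grammar.
# The annotation is a plain string in the trial code, so nothing here executes
# until NNI's annotation compiler rewrites it.
"""@nni.mutable_layers(
{
    layer_choice: [conv(ch=128), pool, identity],
    optional_inputs: [out1, out2, out3],
    optional_input_size: 2,
    layer_output: out
}
)"""
```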
@@ -53,13 +55,16 @@ After finishing the trial code through the annotation above, users have implicit
```javascript
{
    "mutable_1": {
        "_type": "mutable_layer",
        "_value": {
            "layer_1": {
                "layer_choice": ["conv(ch=128)", "pool", "identity"],
                "optional_inputs": ["out1", "out2", "out3"],
                "optional_input_size": 2
            },
            "layer_2": {
                ...
            }
        }
    }
}
```
@@ -83,9 +88,109 @@ Accordingly, a specified neural architecture (generated by tuning algorithm) is

With the specification of the format of search space and architecture (choice) expression, users are free to implement various (general) tuning algorithms for neural architecture search on NNI. One future work is to provide a general NAS algorithm.
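For instance, for the search space above, a choice generated by a tuner could look roughly like the following (the `chosen_layer`/`chosen_inputs` key names are shown only as an illustration, not a definitive spec):

```javascript
{
    "mutable_1": {
        "layer_1": {
            "chosen_layer": "pool",
            "chosen_inputs": ["out1", "out3"]
        }
    }
}
```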
## Support of One-Shot NAS

One-Shot NAS is a popular approach for finding a good neural architecture within a limited time and resource budget. Basically, it builds a full graph based on the search space and uses gradient descent to finally find the best subgraph. There are different training approaches, such as [training subgraphs (per mini-batch)][1], [training the full graph through dropout][6], and [training with architecture weights (regularization)][3].

NNI supports general NAS as demonstrated above. From the user's point of view, One-Shot NAS and classic NAS have the same search space specification, so they share the programming interface demonstrated above and differ only in training mode. NNI provides four training modes:
**\*classic_mode\***: this mode is described [above](#ProgInterface); in it, each subgraph runs as a separate trial job. To use this mode, enable NNI annotation and specify a tuner for NAS in the experiment config file. [Here](https://github.com/microsoft/nni/tree/master/examples/trials/mnist-nas) is an example showing how to write the trial code and the config file, and [here](https://github.com/microsoft/nni/tree/master/examples/tuners/random_nas_tuner) is a simple tuner for NAS.
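To give a feel for what such a tuner involves, below is a minimal sketch of a random NAS tuner, assuming the `mutable_layer` search space format shown earlier and the illustrative `chosen_layer`/`chosen_inputs` choice format; the linked random_nas_tuner is the authoritative example:

```python
import random
from nni.tuner import Tuner

class RandomNASTuner(Tuner):
    '''Sketch: randomly pick one candidate op and a subset of inputs per mutable layer.'''

    def update_search_space(self, search_space):
        self.search_space = search_space

    def generate_parameters(self, parameter_id, **kwargs):
        choice = {}
        for mutable_name, mutable in self.search_space.items():
            if mutable.get('_type') != 'mutable_layer':
                continue
            layers = {}
            for layer_name, spec in mutable['_value'].items():
                layers[layer_name] = {
                    'chosen_layer': random.choice(spec['layer_choice']),
                    'chosen_inputs': random.sample(
                        spec['optional_inputs'], spec['optional_input_size'])
                }
            choice[mutable_name] = layers
        return choice

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        pass  # pure random search ignores trial results
```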
**\*enas_mode\***: follows the training approach in the [ENAS paper][1]. It builds the full graph based on the neural architecture search space, and for each mini-batch activates only the one subgraph generated by the controller. [Detailed description](#ENASMode). (Currently only supported on TensorFlow.)

To use enas_mode, you should add one more field to the `trial` config, as shown below.
```diff
trial:
  command: your command to run the trial
  codeDir: the directory where the trial's code is located
  gpuNum: the number of GPUs that one trial job needs
+  #choice: classic_mode, enas_mode, oneshot_mode
+  nasMode: enas_mode
```
Similar to classic_mode, in enas_mode you need to specify a tuner for NAS, as the trial job also needs to receive subgraphs from the tuner (called a controller in the paper's terminology). Since the trial job needs to receive multiple subgraphs, one per mini-batch, two lines need to be added to the trial code: one to receive and enable the next subgraph (i.e., `nni.training_update`) and one to report the result of the current subgraph. Below is an example:
```python
for _ in range(num):
    # here receive and enable a new subgraph
    """@nni.training_update(tf=tf, session=self.session)"""
    loss, _ = self.session.run([loss_op, train_op])
    # report the loss of this mini-batch
    """@nni.report_final_result(loss)"""
```
Here, `nni.training_update` performs an update on the full graph. In enas_mode, the update means receiving a subgraph and enabling it for the next mini-batch; in darts_mode, it means training the architecture weights (see darts_mode for details). In enas_mode, you need to pass the imported TensorFlow package as `tf` and the session as `session`.
**\*oneshot_mode\***: follows the training approach in [this paper][6]. Unlike enas_mode, which trains the full graph by training a large number of subgraphs, in oneshot_mode the full graph is built and dropout is added to the candidate inputs and to the outputs of the candidate ops. The full graph is then trained like any other DL model. [Detailed description](#OneshotMode). (Currently only supported on TensorFlow.)

To use oneshot_mode, you should add one more field to the `trial` config, as shown below. In this mode no tuner is needed, since no subgraphs are received. (Note that, for now, you still need to specify a tuner (any tuner) in the config file.) Also, there is no need to add `nni.training_update`, because no special update is needed during training.
```diff
trial:
  command: your command to run the trial
  codeDir: the directory where the trial's code is located
  gpuNum: the number of GPUs that one trial job needs
+  #choice: classic_mode, enas_mode, oneshot_mode
+  nasMode: oneshot_mode
```
**\*darts_mode\***: follows the training approach in [this paper][3]. It is similar to oneshot_mode, with two differences: darts_mode adds architecture weights only to the outputs of candidate ops, and it trains model weights and architecture weights in an interleaved manner. [Detailed description](#DartsMode).

To use darts_mode, you should add one more field to the `trial` config, as shown below. In this mode, no tuner is needed either. (Note that, for now, you still need to specify a tuner (any tuner) in the config file.)
```diff
trial:
  command: your command to run the trial
  codeDir: the directory where the trial's code is located
  gpuNum: the number of GPUs that one trial job needs
+  #choice: classic_mode, enas_mode, oneshot_mode
+  nasMode: darts_mode
```
When using darts_mode, you need to call `nni.training_update` as shown below whenever the architecture weights should be updated. Updating the architecture weights requires the `loss` to optimize as well as the training data (i.e., `feed_dict`) on which to compute it.
```python
for _ in range(num):
    # here train the architecture weights
    """@nni.training_update(tf=tf, session=self.session, loss=loss, feed_dict=feed_dict)"""
    loss, _ = self.session.run([loss_op, train_op])
```
**Note:** for enas_mode, oneshot_mode, and darts_mode, NNI only handles the training phase. Each has its own inference phase, which is not handled by NNI. For enas_mode, the inference phase generates new subgraphs through the controller. For oneshot_mode, the inference phase samples new subgraphs randomly and picks the good ones. For darts_mode, the inference phase prunes a proportion of the candidate ops based on the architecture weights.
<a name="ENASMode"></a>
### enas_mode
In enas_mode, the compiled trial code builds the full graph (rather than a subgraph); it receives a chosen architecture, trains that architecture on the full graph for one mini-batch, and then requests the next chosen architecture. This is supported by [NNI multi-phase](./multiPhase.md).

Specifically, for trials using TensorFlow, we create TensorFlow variables as signals and use TensorFlow conditional functions to control the search space (full graph) flexibly, meaning it can be switched between different subgraphs (multiple times) depending on these signals. [Here]() is an example for enas_mode.
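The mechanism can be pictured with a small standalone sketch (not NNI's actual generated code): a non-trainable TensorFlow variable acts as the signal, and `tf.cond` activates only the selected candidate op for the current mini-batch.

```python
import tensorflow as tf  # TF 1.x style, matching the sessions used in this doc

inputs = tf.placeholder(tf.float32, [None, 32, 32, 16])
# non-trainable signal variable; assigning to it switches the active subgraph
choice = tf.get_variable('layer_1_choice', [], tf.int32, trainable=False,
                         initializer=tf.zeros_initializer())

def conv_branch():   # candidate op 0 (keeps the channel count so shapes match)
    return tf.layers.conv2d(inputs, 16, 3, padding='same')

def pool_branch():   # candidate op 1
    return tf.layers.max_pooling2d(inputs, 2, 1, padding='same')

# only the branch selected by the signal runs for the current mini-batch
out = tf.cond(tf.equal(choice, 0), conv_branch, pool_branch)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    session.run(choice.assign(1))  # e.g., a newly received subgraph selects pool
```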
<a name="OneshotMode"></a>
### oneshot_mode
The figure below shows where dropout is added to the full graph for one layer in `nni.mutable_layers`: inputs 1 through k are candidate inputs, and the four ops are candidate ops.
![](../../img/oneshot_mode.png)
As suggested in the [paper][6], a dropout method is applied to the inputs of every layer. The dropout rate is set to r^(1/k), where 0 < r < 1 is a hyper-parameter of the model (default 0.01) and k is the number of optional inputs for a specific layer. The higher the fan-in, the more likely each possible input is to be dropped out. However, the probability of dropping out all optional inputs of a layer is kept constant regardless of its fan-in. Suppose r = 0.05. If a layer has k = 2 optional inputs, then each one will independently be dropped out with probability 0.05^(1/2) ≈ 0.22 and retained with probability 0.78. If a layer has k = 7 optional inputs, then each one will independently be dropped out with probability 0.05^(1/7) ≈ 0.65 and retained with probability 0.35. In both cases, the probability of dropping out all of the layer's optional inputs is 5%. The outputs of candidate ops are dropped out in the same way. [Here]() is an example for oneshot_mode.
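The arithmetic above is easy to verify:

```python
# Verify the r^(1/k) dropout rates quoted above (r = 0.05).
r = 0.05
for k in (2, 7):
    drop = r ** (1.0 / k)   # per-input dropout probability
    print(k, round(drop, 2), round(1 - drop, 2), round(drop ** k, 2))
# k=2: drop 0.22, keep 0.78; k=7: drop 0.65, keep 0.35; drop**k is 0.05 either way
```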
<a name="DartsMode"></a>
### darts_mode
The figure below shows where architecture weights are added to the full graph for one layer in `nni.mutable_layers`: the output of each candidate op is multiplied by a weight, called an architecture weight.
![](../../img/darts_mode.png)
In `nni.training_update`, the TensorFlow MomentumOptimizer is used to train the architecture weights based on the passed `loss` and `feed_dict`. [Here]() is an example for darts_mode.
### [__TODO__] Multiple trial jobs for One-Shot NAS
One-Shot NAS usually has only one trial job with the full graph. However, running multiple such trial jobs brings benefits. For example, in enas_mode multiple trial jobs could share the weights of the full graph to speed up model training (i.e., convergence). Also, some One-Shot approaches are not stable, and running multiple trial jobs increases the chance of finding better models.
NNI natively supports running multiple such trial jobs. The figure below shows how multiple trial jobs run on NNI.
![](../../img/one-shot_training.png)
=============================================================

## System design of NAS on NNI

### Basic flow of experiment execution
@@ -95,7 +200,7 @@ NNI's annotation compiler transforms the annotated trial code to the code that c

The above figure shows how the trial code runs on NNI. `nnictl` processes user trial code to generate a search space file and compiled trial code. The former is fed to the tuner, and the latter is used to run trials.

[Simple example of NAS on NNI](https://github.com/microsoft/nni/tree/master/examples/trials/mnist-nas).
### [__TODO__] Weight sharing

@@ -107,24 +212,9 @@ We believe weight sharing (transferring) plays a key role on speeding up NAS, wh

Example of weight sharing on NNI.
## General tuning algorithms for NAS

Like hyperparameter tuning, a relatively general algorithm for NAS is required. The general programming interface makes this task easier to some extent. We have an [RL tuner based on the PPO algorithm](https://github.com/microsoft/nni/tree/master/src/sdk/pynni/nni/ppo_tuner) for NAS. We expect efforts from the community to design and implement better NAS algorithms.
## [__TODO__] Export best neural architecture and code
@@ -14,7 +14,7 @@ Currently we support the following algorithms:

|[__Naïve Evolution__](#Evolution)|Naïve Evolution comes from Large-Scale Evolution of Image Classifiers. It randomly initializes a population based on the search space. For each generation, it chooses the better ones and applies some mutations (e.g., changing a hyperparameter, adding/removing one layer) to them to get the next generation. Naïve Evolution requires many trials to work, but it is very simple and easy to extend with new features. [Reference paper](https://arxiv.org/pdf/1703.01041.pdf)|
|[__SMAC__](#SMAC)|SMAC is based on Sequential Model-Based Optimization (SMBO). It adapts the most prominent previously used model class (Gaussian stochastic process models) and introduces the model class of random forests to SMBO in order to handle categorical parameters. The SMAC supported by NNI is a wrapper around the SMAC3 GitHub repo. Note that SMAC needs to be installed with the `nnictl package` command. [Reference paper,](https://www.cs.ubc.ca/~hutter/papers/10-TR-SMAC.pdf) [GitHub repo](https://github.com/automl/SMAC3)|
|[__Batch tuner__](#Batch)|Batch tuner allows users to simply provide several configurations (i.e., choices of hyper-parameters) for their trial code. After finishing all the configurations, the experiment is done. Batch tuner only supports the `choice` type in the search space spec.|
|[__Grid Search__](#GridSearch)|Grid Search performs an exhaustive search through a manually specified subset of the hyperparameter space defined in the search space file. Note that the only acceptable types of search space are choice, quniform, and randint.|
|[__Hyperband__](#Hyperband)|Hyperband tries to use limited resources to explore as many configurations as possible and find the promising ones for the final result. The basic idea is to generate many configurations, run them with a small trial budget to find the promising ones, and then train the promising ones further to select several that are even more promising. [Reference paper](https://arxiv.org/pdf/1603.06560.pdf)|
|[__Network Morphism__](#NetworkMorphism)|Network Morphism provides functions to automatically search for the architecture of deep learning models. Every child network inherits the knowledge from its parent network and morphs into diverse types of networks, including changes of depth, width, and skip-connections. Next, it estimates the value of a child network using historic architecture and metric pairs. Then it selects the most promising one to train. [Reference paper](https://arxiv.org/abs/1806.10282)|
|[__Metis Tuner__](#MetisTuner)|Metis offers the following benefits when it comes to tuning parameters: while most tools only predict the optimal configuration, Metis gives you two outputs: (a) the current prediction of the optimal configuration, and (b) a suggestion for the next trial. No more guesswork. While most tools assume training datasets do not have noisy data, Metis actually tells you if you need to re-sample a particular hyper-parameter. [Reference paper](https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/)|
@@ -213,7 +213,7 @@ The search space file including the high-level key `combine_params`. The type of

**Suggested scenario**

Note that the only acceptable types of search space are `choice`, `quniform`, and `randint`.

It is suggested when the search space is small and it is feasible to exhaustively sweep the whole search space. [Detailed Description](./GridsearchTuner.md)
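For example, a search space that Grid Search can fully enumerate might look like the following (hypothetical parameter names; per the tuner's semantics, `quniform` takes `[low, high, interval]` and `randint` covers `[low, high)`):

```javascript
{
    "batch_size": { "_type": "choice", "_value": [16, 32, 64] },
    "hidden_size": { "_type": "quniform", "_value": [64, 256, 64] },
    "num_layers": { "_type": "randint", "_value": [2, 5] }
}
```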
@@ -87,7 +87,7 @@ All types of sampling strategies and their parameter are listed here:

| Evolution Tuner | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; |
| SMAC Tuner | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | | | | | |
| Batch Tuner | &#10003; | | | | | | | | | |
| Grid Search Tuner | &#10003; | &#10003; | | &#10003; | | | | | | |
| Hyperband Advisor | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; |
| Metis Tuner | &#10003; | &#10003; | &#10003; | &#10003; | | | | | | |
| GP Tuner | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | &#10003; | | | | |
@@ -95,12 +95,6 @@ All types of sampling strategies and their parameter are listed here:

Known Limitations:

* Note that Metis Tuner only supports numerical `choice` now
* Note that for nested search space:
@@ -39,16 +39,14 @@ logger = logging.getLogger('grid_search_AutoML')

class GridSearchTuner(Tuner):
    '''
    GridSearchTuner will search all the possible configurations that the user defines in the search space.
    The only acceptable types of search space are 'choice', 'quniform', 'randint'.

    Type 'choice' will select one of the options. Note that it can also be nested.

    Type 'quniform' will receive three values [low, high, q], where [low, high] specifies a range and 'q' specifies the interval.
    It will be sampled in a way that the first sampled value is 'low', and each of the following values is 'interval' larger than the value in front of it.

    Type 'randint' gives all possible integers in the range [low, high). Note that 'high' is not included.
    '''

    def __init__(self):
@@ -73,8 +71,12 @@ class GridSearchTuner(Tuner):

                    chosen_params.extend(choice)
                else:
                    chosen_params.append(choice)
            elif _type == 'quniform':
                chosen_params = self._parse_quniform(_value)
            elif _type == 'randint':
                chosen_params = self._parse_randint(_value)
            else:
                raise RuntimeError("Not supported type: %s" % _type)
        else:
            chosen_params = dict()
            for key in ss_spec.keys():
@@ -95,21 +97,13 @@ class GridSearchTuner(Tuner):

    def _parse_quniform(self, param_value):
        '''parse type of quniform parameter and return a list'''
        low, high, interval = param_value[0], param_value[1], param_value[2]
        count = int(np.floor((high - low) / interval)) + 1
        return [low + interval * i for i in range(count)]

    def _parse_randint(self, param_value):
        '''parse type of randint parameter and return a list'''
        return np.arange(param_value[0], param_value[1]).tolist()

    def expand_parameters(self, para):
        '''

@@ -133,7 +127,7 @@ class GridSearchTuner(Tuner):

    def update_search_space(self, search_space):
        '''
        Check if the search space is valid and expand it: supports only 'choice', 'quniform', 'randint'
        '''
        self.expanded_search_space = self.json2parameter(search_space)
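As an illustrative check of the new expansion semantics (hypothetical values), `quniform` with `[64, 256, 64]` expands to `[64, 128, 192, 256]`, and `randint` with `[2, 5]` expands to `[2, 3, 4]`:

```python
import numpy as np

def parse_quniform(param_value):
    # mirrors _parse_quniform above: the third element is the sampling interval
    low, high, interval = param_value
    count = int(np.floor((high - low) / interval)) + 1
    return [low + interval * i for i in range(count)]

print(parse_quniform([64, 256, 64]))   # [64, 128, 192, 256]
print(np.arange(2, 5).tolist())        # randint over [2, 5) -> [2, 3, 4]
```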
import * as React from 'react';
import { Row, Modal } from 'antd';
import ReactEcharts from 'echarts-for-react';
import IntermediateVal from '../public-child/IntermediateVal';
import '../../static/style/compare.scss';
import { TableObj, Intermedia, TooltipForIntermediate } from 'src/static/interface';
@@ -83,13 +83,13 @@ class Compare extends React.Component<CompareProps, {}> {

            },
            xAxis: {
                type: 'category',
                // name: '# Intermediate',
                boundaryGap: false,
                data: xAxis
            },
            yAxis: {
                type: 'value',
                name: 'Metric'
            },
            series: trialIntermediate
        };
@@ -137,7 +137,7 @@ class Compare extends React.Component<CompareProps, {}> {

                        const temp = compareRows[index];
                        return (
                            <td className="value" key={index}>
                                <IntermediateVal record={temp} />
                            </td>
                        );
                    })}
@@ -193,9 +193,11 @@ class Compare extends React.Component<CompareProps, {}> {

                destroyOnClose={true}
                maskClosable={false}
                width="90%"
            >
                <Row className="compare-intermeidate">
                    {this.intermediate()}
                    <Row className="compare-yAxis"># Intermediate</Row>
                </Row>
                <Row>{this.initColumn()}</Row>
            </Modal>
        );
@@ -3,7 +3,7 @@ import axios from 'axios';

import { MANAGER_IP } from '../static/const';
import { Row, Col, Tabs, Select, Button, Icon } from 'antd';
const Option = Select.Option;
import { TableObj, Parameters, ExperimentInfo } from '../static/interface';
import { getFinal } from '../static/function';
import DefaultPoint from './trial-detail/DefaultMetricPoint';
import Duration from './trial-detail/Duration';
@@ -21,8 +21,6 @@ interface TrialDetailState {

    tableListSource: Array<TableObj>;
    searchResultSource: Array<TableObj>;
    isHasSearch: boolean;
    experimentLogCollection: boolean;
    entriesTable: number; // table components val
    entriesInSelect: string;

@@ -32,6 +30,7 @@ interface TrialDetailState {

    hyperCounts: number; // user click the hyper-parameter counts
    durationCounts: number;
    intermediateCounts: number;
    experimentInfo: ExperimentInfo;
    searchFilter: string;
    searchPlaceHolder: string;
}
@@ -78,8 +77,6 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>

            accNodata: '',
            tableListSource: [],
            searchResultSource: [],
            experimentLogCollection: false,
            entriesTable: 20,
            entriesInSelect: '20',

@@ -90,6 +87,10 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>

            hyperCounts: 0,
            durationCounts: 0,
            intermediateCounts: 0,
            experimentInfo: {
                platform: '',
                optimizeMode: 'maximize'
            },
            searchFilter: 'id',
            searchPlaceHolder: 'Search by id'
        };
@@ -326,7 +327,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>

            })
            .then(res => {
                if (res.status === 200) {
                    const trainingPlatform: string = res.data.params.trainingServicePlatform !== undefined
                        ?
                        res.data.params.trainingServicePlatform
                        :

@@ -336,12 +337,24 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>

                    let expLogCollection: boolean = false;
                    const isMultiy: boolean = res.data.params.multiPhase !== undefined
                        ? res.data.params.multiPhase : false;
                    const tuner = res.data.params.tuner;
                    // default to maximize if the user did not set optimize_mode
                    let optimize: string = 'maximize';
                    if (tuner !== undefined) {
                        if (tuner.classArgs !== undefined) {
                            if (tuner.classArgs.optimize_mode !== undefined) {
                                if (tuner.classArgs.optimize_mode === 'minimize') {
                                    optimize = 'minimize';
                                }
                            }
                        }
                    }
                    if (logCollection !== undefined && logCollection !== 'none') {
                        expLogCollection = true;
                    }
                    if (this._isMounted) {
                        this.setState({
                            experimentInfo: { platform: trainingPlatform, optimizeMode: optimize },
                            searchSpace: res.data.params.searchSpace,
                            experimentLogCollection: expLogCollection,
                            isMultiPhase: isMultiy
@@ -380,7 +393,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>

        const {
            tableListSource, searchResultSource, isHasSearch, isMultiPhase,
            entriesTable, experimentInfo, searchSpace, experimentLogCollection,
            whichGraph, searchPlaceHolder
        } = this.state;
        const source = isHasSearch ? searchResultSource : tableListSource;
@@ -391,9 +404,10 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>

                <TabPane tab={this.titleOfacc} key="1">
                    <Row className="graph">
                        <DefaultPoint
                            height={402}
                            showSource={source}
                            whichGraph={whichGraph}
                            optimize={experimentInfo.optimizeMode}
                        />
                    </Row>
                </TabPane>
@@ -465,7 +479,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>

                        entries={entriesTable}
                        tableSource={source}
                        isMultiPhase={isMultiPhase}
                        platform={experimentInfo.platform}
                        updateList={this.getDetailSource}
                        logCollection={experimentLogCollection}
                        ref={(tabList) => this.tableList = tabList}
import * as React from 'react';
import { Switch } from 'antd';
import ReactEcharts from 'echarts-for-react';
import { filterByStatus } from '../../static/function';
import { TableObj, DetailAccurPoint, TooltipForAccuracy } from '../../static/interface';

@@ -10,32 +11,36 @@ interface DefaultPointProps {

    showSource: Array<TableObj>;
    height: number;
    whichGraph: string;
    optimize: string;
}

interface DefaultPointState {
    defaultSource: object;
    accNodata: string;
    succeedTrials: number;
    isViewBestCurve: boolean;
}

class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> {
    public _isDefaultMounted = false;

    constructor(props: DefaultPointProps) {
        super(props);
        this.state = {
            defaultSource: {},
            accNodata: '',
            succeedTrials: 10000000,
            isViewBestCurve: false
        };
    }

    defaultMetric = (succeedSource: Array<TableObj>, isCurve: boolean) => {
        const { optimize } = this.props;
        const accSource: Array<DetailAccurPoint> = [];
        const showSource: Array<TableObj> = succeedSource.filter(filterByStatus);
        const lengthOfSource = showSource.length;
        const tooltipDefault = lengthOfSource === 0 ? 'No data' : '';
        if (this._isDefaultMounted === true) {
            this.setState(() => ({
                succeedTrials: lengthOfSource,
                accNodata: tooltipDefault
@@ -55,95 +60,195 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>

                    type: 'value',
                }
            };
            if (this._isDefaultMounted === true) {
                this.setState(() => ({
                    defaultSource: nullGraph
                }));
            }
        } else {
            const resultList: Array<number | object>[] = [];
            const lineListDefault: Array<number> = [];
            Object.keys(showSource).map(item => {
                const temp = showSource[item];
                if (temp.acc !== undefined) {
                    if (temp.acc.default !== undefined) {
                        const searchSpace = temp.description.parameters;
                        lineListDefault.push(temp.acc.default);
                        accSource.push({
                            acc: temp.acc.default,
                            index: temp.sequenceId,
                            searchSpace: searchSpace
                        });
                    }
                }
            });
            // deal with the best-metric line
            const bestCurve: Array<number | object>[] = []; // best curve data source
            bestCurve.push([0, lineListDefault[0], accSource[0].searchSpace]); // push the first value
            if (optimize === 'maximize') {
                for (let i = 1; i < lineListDefault.length; i++) {
                    const val = lineListDefault[i];
                    const latest = bestCurve[bestCurve.length - 1][1];
                    if (val >= latest) {
                        bestCurve.push([i, val, accSource[i].searchSpace]);
                    } else {
                        bestCurve.push([i, latest, accSource[i].searchSpace]);
                    }
                }
            } else {
                for (let i = 1; i < lineListDefault.length; i++) {
                    const val = lineListDefault[i];
                    const latest = bestCurve[bestCurve.length - 1][1];
                    if (val <= latest) {
                        bestCurve.push([i, val, accSource[i].searchSpace]);
                    } else {
                        bestCurve.push([i, latest, accSource[i].searchSpace]);
                    }
                }
            }
            Object.keys(accSource).map(item => {
                const items = accSource[item];
                let temp: Array<number | object>;
                temp = [items.index, items.acc, items.searchSpace];
                resultList.push(temp);
            });
            // isViewBestCurve: false shows the default metric graph
            // isViewBestCurve: true shows the best curve
            if (isCurve === true) {
                if (this._isDefaultMounted === true) {
                    this.setState(() => ({
                        defaultSource: this.drawBestcurve(bestCurve, resultList)
                    }));
                }
            } else {
                if (this._isDefaultMounted === true) {
                    this.setState(() => ({
                        defaultSource: this.drawDefaultMetric(resultList)
                    }));
                }
            }
        }
    }
    drawBestcurve = (realDefault: Array<number | object>[], resultList: Array<number | object>[]) => {
        return {
            grid: {
                left: '8%'
            },
            tooltip: {
                trigger: 'item',
                enterable: true,
                position: function (point: Array<number>, data: TooltipForAccuracy) {
                    if (data.data[0] < realDefault.length / 2) {
                        return [point[0], 80];
                    } else {
                        return [point[0] - 300, 80];
                    }
                },
                formatter: function (data: TooltipForAccuracy) {
                    const result = '<div class="tooldetailAccuracy">' +
                        '<div>Trial No.: ' + data.data[0] + '</div>' +
                        '<div>Optimization curve: ' + data.data[1] + '</div>' +
                        '<div>Parameters: ' +
                        '<pre>' + JSON.stringify(data.data[2], null, 4) + '</pre>' +
                        '</div>' +
                        '</div>';
                    return result;
                }
            },
            xAxis: {
                name: 'Trial',
                type: 'category',
            },
            yAxis: {
                name: 'Default metric',
                type: 'value',
                scale: true
            },
            series: [{
                symbolSize: 6,
                type: 'scatter',
                data: resultList
            }, {
                type: 'line',
                lineStyle: { color: '#FF6600' },
                data: realDefault
            }]
        };
    }

    drawDefaultMetric = (resultList: Array<number | object>[]) => {
        return {
            grid: {
                left: '8%'
            },
            tooltip: {
                trigger: 'item',
                enterable: true,
                position: function (point: Array<number>, data: TooltipForAccuracy) {
                    if (data.data[0] < resultList.length / 2) {
                        return [point[0], 80];
                    } else {
                        return [point[0] - 300, 80];
                    }
                },
                formatter: function (data: TooltipForAccuracy) {
                    const result = '<div class="tooldetailAccuracy">' +
                        '<div>Trial No.: ' + data.data[0] + '</div>' +
                        '<div>Default metric: ' + data.data[1] + '</div>' +
                        '<div>Parameters: ' +
                        '<pre>' + JSON.stringify(data.data[2], null, 4) + '</pre>' +
                        '</div>' +
                        '</div>';
                    return result;
                }
            },
            xAxis: {
                name: 'Trial',
                type: 'category',
            },
            yAxis: {
                name: 'Default metric',
                type: 'value',
                scale: true
            },
            series: [{
                symbolSize: 6,
                type: 'scatter',
                data: resultList
            }]
        };
    }

    loadDefault = (checked: boolean) => {
        // checked: true shows the best metric curve
        const { showSource } = this.props;
        if (this._isDefaultMounted === true) {
            this.defaultMetric(showSource, checked);
            // deal with the data and then update the view layer
            this.setState(() => ({ isViewBestCurve: checked }));
        }
    }
    // update parent component state
    componentWillReceiveProps(nextProps: DefaultPointProps) {
        const { whichGraph, showSource } = nextProps;
        const { isViewBestCurve } = this.state;
        if (whichGraph === '1') {
            this.defaultMetric(showSource, isViewBestCurve);
        }
    }

    shouldComponentUpdate(nextProps: DefaultPointProps, nextState: DefaultPointState) {
        const { whichGraph } = nextProps;
        if (whichGraph === '1') {
            const { succeedTrials, isViewBestCurve } = nextState;
            const succTrial = this.state.succeedTrials;
            const isViewBestCurveBefore = this.state.isViewBestCurve;
            if (isViewBestCurveBefore !== isViewBestCurve) {
                return true;
            }
            if (succeedTrials !== succTrial) {
                return true;
            }
@@ -153,11 +258,11 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>

    }

    componentDidMount() {
        this._isDefaultMounted = true;
    }

    componentWillUnmount() {
        this._isDefaultMounted = false;
    }

    render() {

@@ -165,6 +270,12 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>

        const { defaultSource, accNodata } = this.state;
        return (
            <div>
                <div className="default-metric">
                    <div className="position">
                        <span className="bold">optimization curve</span>
                        <Switch defaultChecked={false} onChange={this.loadDefault} />
                    </div>
                </div>
                <ReactEcharts
                    option={defaultSource}
                    style={{

@@ -174,7 +285,6 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>

                    }}
                    theme="my_theme"
                    notMerge={true} // update now
                />
                <div className="showMess">{accNodata}</div>
            </div>
import * as React from 'react';
import { Row, Button, Switch } from 'antd';
import { TooltipForIntermediate, TableObj, Intermedia } from '../../static/interface';
import ReactEcharts from 'echarts-for-react';
require('echarts/lib/component/tooltip');
@@ -108,13 +108,13 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>

            },
            xAxis: {
                type: 'category',
                // name: '# Intermediate',
                boundaryGap: false,
                data: xAxis
            },
            yAxis: {
                type: 'value',
                name: 'Metric'
            },
            series: trialIntermediate
        };

@@ -136,7 +136,7 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>

            },
            yAxis: {
                type: 'value',
                name: 'Metric'
            }
        };
        if (this._isMounted) {
@@ -282,58 +282,52 @@ class Intermediate extends React.Component<IntermediateProps, IntermediateState>

            <div>
                {/* style in para.scss */}
                <Row className="meline intermediate">
                    {/* filter message */}
                    <span>Filter</span>
                    <Switch
                        defaultChecked={false}
                        onChange={this.switchTurn}
                    />
                    {
                        isFilter
                            ?
                            <span>
                                <span className="filter-x"># Intermediate</span>
                                <input
                                    // placeholder="point"
                                    ref={input => this.pointInput = input}
                                    className="strange"
                                />
                                <span>Metric range</span>
                                <input
                                    // placeholder="range"
                                    ref={input => this.minValInput = input}
                                />
                                <span className="hyphen">-</span>
                                <input
                                    // placeholder="range"
                                    ref={input => this.maxValInput = input}
                                />
                                <Button
                                    type="primary"
                                    className="changeBtu tableButton"
                                    onClick={this.filterLines}
                                    disabled={isLoadconfirmBtn}
                                >
                                    Confirm
                                </Button>
                            </span>
                            :
                            null
                    }
                </Row>
                <Row className="intermeidate-graph">
                    <ReactEcharts
                        option={interSource}
                        style={{ width: '100%', height: 418, margin: '0 auto' }}
                        notMerge={true} // update now
                    />
                    <div className="yAxis"># Intermediate</div>
                </Row>
            </div>
        );
@@ -87,13 +87,10 @@ class Para extends React.Component<ParaProps, ParaState> {

            let temp: Array<number> = [];
            for (let i = 0; i < dimName.length; i++) {
                if ('type' in parallelAxis[i]) {
                    temp.push(eachTrialParams[item][dimName[i]].toString());
                } else {
                    // default metric
                    temp.push(eachTrialParams[item][dimName[i]]);
                }
            }
            paraYdata.push(temp);
@@ -199,11 +196,18 @@ class Para extends React.Component<ParaProps, ParaState> {

                    break;
                // support log distribution
                case 'loguniform':
                    if (lenOfDataSource > 1) {
                        parallelAxis.push({
                            dim: i,
                            name: dimName[i],
                            type: 'log',
                        });
                    } else {
                        parallelAxis.push({
                            dim: i,
                            name: dimName[i]
                        });
                    }
                    break;
                default:
import * as React from 'react';
import axios from 'axios';
import ReactEcharts from 'echarts-for-react';
import { Row, Table, Button, Popconfirm, Modal, Checkbox, Select, Icon } from 'antd';
const Option = Select.Option;
const CheckboxGroup = Checkbox.Group;
import { MANAGER_IP, trialJobStatus, COLUMN, COLUMN_INDEX, COLUMNPro } from '../../static/const';
import { convertDuration, intermediateGraphOption, killJob, filterByStatus } from '../../static/function';
import { TableObj, TrialJob } from '../../static/interface';
import OpenRow from '../public-child/OpenRow';
import Compare from '../Modal/Compare';
...@@ -180,7 +180,8 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -180,7 +180,8 @@ class TableList extends React.Component<TableListProps, TableListState> {
// checkbox for column // checkbox for column
selectedColumn = (checkedValues: Array<string>) => { selectedColumn = (checkedValues: Array<string>) => {
let count = 6; // 7 because there are seven common columns; "Intermediate count" is not shown by default
let count = 7;
const want: Array<object> = []; const want: Array<object> = [];
const finalKeys: Array<string> = []; const finalKeys: Array<string> = [];
const wantResult: Array<string> = []; const wantResult: Array<string> = [];
...@@ -192,7 +193,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -192,7 +193,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
case 'Status': case 'Status':
case 'Operation': case 'Operation':
case 'Default': case 'Default':
case 'Intermediate result': case 'Intermediate count':
break; break;
default: default:
finalKeys.push(checkedValues[m]); finalKeys.push(checkedValues[m]);
...@@ -285,23 +286,27 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -285,23 +286,27 @@ class TableList extends React.Component<TableListProps, TableListState> {
this.fillSelectedRowsTostate(selected, selectedRows); this.fillSelectedRowsTostate(selected, selectedRows);
} }
}; };
let showTitle = COLUMN; let showTitle = COLUMNPro;
let bgColor = ''; let bgColor = '';
const trialJob: Array<TrialJob> = []; const trialJob: Array<TrialJob> = [];
const showColumn: Array<object> = []; const showColumn: Array<object> = [];
if (tableSource.length >= 1) { // only succeeded trials have final keys
const temp = tableSource[0].acc; if (tableSource.filter(filterByStatus).length >= 1) {
const temp = tableSource.filter(filterByStatus)[0].acc;
if (temp !== undefined && typeof temp === 'object') { if (temp !== undefined && typeof temp === 'object') {
if (this._isMounted) { if (this._isMounted) {
// concat default columns and final keys // concat default columns and final keys
const item = Object.keys(temp); const item = Object.keys(temp);
const want: Array<string> = []; // item: the final metric's keys, e.g. ['default', 'loss']
Object.keys(item).map(key => { if (item.length > 1) {
if (item[key] !== 'default') { const want: Array<string> = [];
want.push(item[key]); item.forEach(value => {
} if (value !== 'default') {
}); want.push(value);
showTitle = COLUMN.concat(want); }
});
showTitle = COLUMNPro.concat(want);
}
} }
} }
} }
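The guard above now derives the extra metric columns from the first succeeded trial rather than from `tableSource[0]`, since only succeeded trials carry final-result keys. A compact sketch under assumed types, with `filterByStatus` approximated by a plain status check:

```ts
// Append every key of the first succeeded trial's dictionary-valued final
// result (except 'default') to the base column list.
interface TrialRow {
    status: string;
    acc?: { [key: string]: number };
}

const BASE_COLUMNS = ['Trial No.', 'ID', 'Duration', 'Status', 'Intermediate count', 'Default', 'Operation'];

function deriveColumns(rows: TrialRow[]): string[] {
    const succeeded = rows.filter(row => row.status === 'SUCCEEDED');
    if (succeeded.length === 0 || typeof succeeded[0].acc !== 'object') {
        return BASE_COLUMNS;
    }
    const extraKeys = Object.keys(succeeded[0].acc!).filter(key => key !== 'default');
    return BASE_COLUMNS.concat(extraKeys);
}
```

Storing the filtered array in a local variable, as the sketch does, also avoids running the status filter twice the way the hunk above does.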
...@@ -345,7 +350,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -345,7 +350,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
title: 'Duration', title: 'Duration',
dataIndex: 'duration', dataIndex: 'duration',
key: 'duration', key: 'duration',
width: 140, width: 100,
// numeric sort // numeric sort
sorter: (a: TableObj, b: TableObj) => (a.duration as number) - (b.duration as number), sorter: (a: TableObj, b: TableObj) => (a.duration as number) - (b.duration as number),
render: (text: string, record: TableObj) => { render: (text: string, record: TableObj) => {
...@@ -387,6 +392,19 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -387,6 +392,19 @@ class TableList extends React.Component<TableListProps, TableListState> {
sorter: (a: TableObj, b: TableObj): number => a.status.localeCompare(b.status) sorter: (a: TableObj, b: TableObj): number => a.status.localeCompare(b.status)
}); });
break; break;
case 'Intermediate count':
showColumn.push({
title: 'Intermediate count',
dataIndex: 'progress',
key: 'progress',
width: 86,
render: (text: string, record: TableObj) => {
return (
<span>{`#${record.description.intermediate.length}`}</span>
);
},
});
break;
case 'Default': case 'Default':
showColumn.push({ showColumn.push({
title: 'Default metric', title: 'Default metric',
...@@ -415,37 +433,37 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -415,37 +433,37 @@ class TableList extends React.Component<TableListProps, TableListState> {
title: 'Operation', title: 'Operation',
dataIndex: 'operation', dataIndex: 'operation',
key: 'operation', key: 'operation',
width: 90, width: 120,
render: (text: string, record: TableObj) => { render: (text: string, record: TableObj) => {
let trialStatus = record.status; let trialStatus = record.status;
let flagKill = false; const flag: boolean = trialStatus !== 'RUNNING';
if (trialStatus === 'RUNNING') {
flagKill = true;
} else {
flagKill = false;
}
return ( return (
flagKill <Row id="detail-button">
? {/* see intermediate result graph */}
( <Button
<Popconfirm type="primary"
title="Are you sure to cancel this trial?" className="common-style"
onConfirm={killJob. onClick={this.showIntermediateModal.bind(this, record.id)}
bind(this, record.key, record.id, record.status, updateList)} title="Intermediate"
> >
<Button type="primary" className="tableButton">Kill</Button> <Icon type="line-chart" />
</Popconfirm> </Button>
) {/* kill job */}
: <Popconfirm
( title="Are you sure you want to cancel this trial?"
onConfirm={killJob.
bind(this, record.key, record.id, record.status, updateList)}
>
<Button <Button
type="primary" type="default"
className="tableButton" disabled={flag}
disabled={true} className="margin-mediate special"
title="kill"
> >
Kill <Icon type="stop" />
</Button> </Button>
) </Popconfirm>
</Row>
); );
}, },
}); });
......
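The operation cell now pairs an always-available intermediate-result button with a kill button that is enabled only while the trial is RUNNING, the latter guarded by a confirmation popover. A stripped-down sketch of that pattern; the handler names are placeholders, not the actual NNI callbacks:

```tsx
import * as React from 'react';
import { Button, Popconfirm } from 'antd';

interface OperationCellProps {
    status: string;
    onKill: () => void;
}

// A trial can only be killed while it is still running; the Popconfirm
// asks for confirmation before the handler fires.
const OperationCell = (props: OperationCellProps): JSX.Element => (
    <Popconfirm title="Are you sure you want to cancel this trial?" onConfirm={props.onKill}>
        <Button type="default" disabled={props.status !== 'RUNNING'}>
            Kill
        </Button>
    </Popconfirm>
);

export default OperationCell;
```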
...@@ -42,20 +42,23 @@ const COLUMN_INDEX = [ ...@@ -42,20 +42,23 @@ const COLUMN_INDEX = [
index: 4 index: 4
}, },
{ {
name: 'Default', name: 'Intermediate count',
index: 5 index: 5
}, },
{ {
name: 'Operation', name: 'Default',
index: 10000 index: 6
}, },
{ {
name: 'Intermediate result', name: 'Operation',
index: 10001 index: 10000
} }
]; ];
const COLUMN = ['Trial No.', 'ID', 'Duration', 'Status', 'Default', 'Operation', 'Intermediate result']; // default selected columns
const COLUMN = ['Trial No.', 'ID', 'Duration', 'Status', 'Default', 'Operation'];
// all selectable columns; dictionary final-metric keys are appended at runtime
const COLUMNPro = ['Trial No.', 'ID', 'Duration', 'Status', 'Intermediate count', 'Default', 'Operation'];
export { export {
MANAGER_IP, DOWNLOAD_IP, trialJobStatus, MANAGER_IP, DOWNLOAD_IP, trialJobStatus, COLUMNPro,
CONTROLTYPE, MONACO, COLUMN, COLUMN_INDEX, DRAWEROPTION CONTROLTYPE, MONACO, COLUMN, COLUMN_INDEX, DRAWEROPTION
}; };
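With 'Intermediate count' inserted at index 5, 'Default' shifts to 6 while 'Operation' keeps its 10000 sentinel so it always sorts last. A hypothetical consumer of these constants, for illustration only:

```ts
// Order selected column names by their configured index so that, e.g.,
// 'Operation' lands at the end regardless of the order the user ticked them.
const COLUMN_INDEX = [
    { name: 'Trial No.', index: 1 },
    { name: 'Intermediate count', index: 5 },
    { name: 'Default', index: 6 },
    { name: 'Operation', index: 10000 }
];

function orderColumns(selected: string[]): string[] {
    const indexOf = (name: string): number => {
        const entry = COLUMN_INDEX.find(item => item.name === name);
        return entry !== undefined ? entry.index : 0;
    };
    return selected.slice().sort((a, b) => indexOf(a) - indexOf(b));
}
```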
...@@ -59,7 +59,7 @@ interface AccurPoint { ...@@ -59,7 +59,7 @@ interface AccurPoint {
interface DetailAccurPoint { interface DetailAccurPoint {
acc: number; acc: number;
index: number; index: number;
searchSpace: string; searchSpace: object;
} }
interface TooltipForIntermediate { interface TooltipForIntermediate {
...@@ -117,8 +117,13 @@ interface Intermedia { ...@@ -117,8 +117,13 @@ interface Intermedia {
hyperPara: object; // each trial's hyperparameter values hyperPara: object; // each trial's hyperparameter values
} }
interface ExperimentInfo {
platform: string;
optimizeMode: string;
}
export { export {
TableObj, Parameters, Experiment, AccurPoint, TrialNumber, TrialJob, TableObj, Parameters, Experiment, AccurPoint, TrialNumber, TrialJob,
DetailAccurPoint, TooltipForAccuracy, ParaObj, Dimobj, FinalResult, FinalType, DetailAccurPoint, TooltipForAccuracy, ParaObj, Dimobj, FinalResult, FinalType,
TooltipForIntermediate, SearchSpace, Intermedia TooltipForIntermediate, SearchSpace, Intermedia, ExperimentInfo
}; };
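`ExperimentInfo` carries the training platform and the optimize mode. A hypothetical consumer showing how the mode would flip best-metric selection (assumed usage, not code from this change):

```ts
interface ExperimentInfo {
    platform: string;
    optimizeMode: string;
}

// 'maximize' vs 'minimize' decides which end of the metric range is "best".
function bestMetric(values: number[], info: ExperimentInfo): number {
    return info.optimizeMode === 'minimize' ? Math.min(...values) : Math.max(...values);
}
```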
...@@ -23,3 +23,13 @@ ...@@ -23,3 +23,13 @@
font-weight: 600; font-weight: 600;
} }
} }
.compare-intermeidate{
position: relative;
.compare-yAxis{
color: #333;
position: absolute;
top: 87%;
left: 45%;
}
}
...@@ -27,19 +27,23 @@ ...@@ -27,19 +27,23 @@
/* Intermediate Result Style */ /* Intermediate Result Style */
.intermediate{ .intermediate{
width: 90%;
text-align: right;
/* border: 1px solid blue; */ /* border: 1px solid blue; */
input{ input{
width: 80px; width: 64px;
height: 32px; height: 32px;
padding-left: 8px; padding-left: 8px;
} }
.strange{ .strange{
margin-top: 2px; margin-top: 2px;
margin-right: 15px;
} }
.range{ .hyphen{
.heng{ margin-left: 6px;
margin-left: 6px; margin-right: 6px;
margin-right: 6px; }
} .filter-x{
margin-left: 15px;
} }
} }
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
} }
.probar{ .probar{
width: 95%;
height: 34px; height: 34px;
margin-top: 15px; margin-top: 15px;
......
...@@ -31,14 +31,12 @@ ...@@ -31,14 +31,12 @@
text-align: center; text-align: center;
color:#212121; color:#212121;
font-size: 14px; font-size: 14px;
/* background-color: #f2f2f2; */
} }
th{ th{
padding: 2px; padding: 2px;
background-color:white !important; background-color:white !important;
font-size: 14px; font-size: 14px;
color: #808080; color: #808080;
border-bottom: 1px solid #d0d0d0;
text-align: center; text-align: center;
} }
...@@ -105,3 +103,51 @@ ...@@ -105,3 +103,51 @@
.ant-table-selection{ .ant-table-selection{
display: none; display: none;
} }
/* fix the border-bottom bug in Firefox and Edge */
.ant-table-thead > tr > th .ant-table-column-sorters::before{
padding-bottom: 25px;
border-bottom: 1px solid #e8e8e8;
}
.margin-mediate{
margin: 0 10px;
}
#detail-button{
.common-style, .common-style:visited, .common-style:focus{
height: 26px;
border: none;
border-radius: 0;
background-color: #0078d4;
}
.common-style:hover{
background-color: #106ebe;
}
.common-style:active{
background-color: #005a9e;
outline: 0;
}
.common-style:disabled{
background-color: #f4f4f4;
}
.special, .special:visited, .special:focus{
height: 26px;
border: none;
border-radius: 0;
outline: 0;
background-color: #f4f4f4;
color: #333;
}
.special:hover{
background-color: #eaeaea;
}
.special:active{
background-color: #c8c8c8;
outline: 0;
}
.special:disabled{
background-color: #f4f4f4;
color: #d9d9d9;
}
}
...@@ -70,3 +70,27 @@ ...@@ -70,3 +70,27 @@
.allList{ .allList{
margin-top: 15px; margin-top: 15px;
} }
.default-metric{
width: 90%;
text-align: right;
margin-top: 15px;
.position{
color: #333;
.bold{
font-weight: 600;
margin-right: 10px;
}
}
}
/* for the # Intermediate label in the intermediate graph */
.intermediate-graph{
position: relative;
.yAxis{
color: #333;
position: absolute;
left: 45%;
top: 86%;
}
}