Commit bc55eec6 authored by liuzhe-lz, committed by GitHub

Align `nni.experiment` tuner behavior with nnictl (#3419)

parent f8978297
...@@ -13,59 +13,50 @@ Since ``nni v2.0``, we provide a new way to launch experiments. Before that, you ...@@ -13,59 +13,50 @@ Since ``nni v2.0``, we provide a new way to launch experiments. Before that, you
Run a New Experiment Run a New Experiment
-------------------- --------------------
After successfully installing ``nni``, you can start the experiment with a python script in the following 3 steps. After successfully installing ``nni``, you can start the experiment with a python script in the following 2 steps.
.. ..
Step 1 - Initialize a tuner you want to use Step 1 - Initialize an experiment instance and configure it
.. code-block:: python .. code-block:: python
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner from nni.experiment import Experiment
tuner = HyperoptTuner('tpe') experiment = Experiment('local')
Very simple, you have successfully initialized a ``HyperoptTuner`` instance called ``tuner``.
See all real `builtin tuners <../builtin_tuner.rst>`__ supported in NNI.
..
Step 2 - Initialize an experiment instance and configure it
.. code-block:: python
experiment = Experiment(tuner=tuner, training_service='local')
Now, you have an ``Experiment`` instance with the ``tuner`` you initialized in the previous step, and this experiment will launch trials on your local machine due to ``training_service='local'``. Now, you have an ``Experiment`` instance, and this experiment will launch trials on your local machine due to ``training_service='local'``.
See all `training services <../training_services.rst>`__ supported in NNI. See all `training services <../training_services.rst>`__ supported in NNI.
.. code-block:: python .. code-block:: python
experiment.config.experiment_name = 'test' experiment.config.experiment_name = 'MNIST example'
experiment.config.trial_concurrency = 2 experiment.config.trial_concurrency = 2
experiment.config.max_trial_number = 5 experiment.config.max_trial_number = 10
experiment.config.search_space = search_space experiment.config.search_space = search_space
experiment.config.trial_command = 'python3 mnist.py' experiment.config.trial_command = 'python3 mnist.py'
experiment.config.trial_code_directory = Path(__file__).parent experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.training_service.use_active_gpu = True experiment.config.training_service.use_active_gpu = True
Use the form like ``experiment.config.foo = 'bar'`` to configure your experiment. Use the form like ``experiment.config.foo = 'bar'`` to configure your experiment.
See all real `builtin tuners <../builtin_tuner.rst>`__ supported in NNI.
See `parameter configuration <../reference/experiment_config.rst>`__ required by different training services. See `parameter configuration <../reference/experiment_config.rst>`__ required by different training services.
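To make the ``search_space`` setting more concrete, here is a minimal sketch of how a dictionary in the format used by this tutorial's example (``_type`` plus ``_value``) can be turned into one set of trial parameters. It uses plain random sampling from the standard library, not the TPE tuner that the configuration selects, and ``sample_parameters`` is a hypothetical helper written for illustration, not part of NNI.

```python
import random

# Search space in the same "_type" / "_value" format as this tutorial's example.
search_space = {
    "dropout_rate": {"_type": "uniform", "_value": [0.5, 0.9]},
    "learning_rate": {"_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1]},
}


def sample_parameters(space, rng):
    """Draw one parameter set by plain random sampling (illustration only, not TPE)."""
    params = {}
    for name, spec in space.items():
        kind, value = spec["_type"], spec["_value"]
        if kind == "uniform":
            low, high = value
            params[name] = rng.uniform(low, high)  # float between low and high
        elif kind == "choice":
            params[name] = rng.choice(value)  # one of the listed options
        else:
            raise ValueError(f"unsupported _type: {kind}")
    return params


rng = random.Random(0)  # fixed seed so repeated runs give the same draw
params = sample_parameters(search_space, rng)
print(params)
```

A real tuner proposes parameters based on earlier trial results rather than drawing them independently, so this only shows what the search space describes.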
.. ..
Step 3 - Just run Step 2 - Just run
.. code-block:: python .. code-block:: python
experiment.run(port=8081) experiment.run(port=8080)
Now, you have successfully launched an NNI experiment. And you can type ``localhost:8081`` in your browser to observe your experiment in real time. Now, you have successfully launched an NNI experiment. And you can type ``localhost:8080`` in your browser to observe your experiment in real time.
.. Note:: In this way, the experiment runs in the foreground and exits automatically when it finishes. If you want to run an experiment interactively, use ``start()`` in Step 3. .. Note:: In this way, the experiment runs in the foreground and exits automatically when it finishes. If you want to run an experiment interactively, use ``start()`` in Step 2.
Example Example
^^^^^^^ ^^^^^^^
...@@ -74,10 +65,8 @@ Below is an example for this new launching approach. You can also find this code ...@@ -74,10 +65,8 @@ Below is an example for this new launching approach. You can also find this code
.. code-block:: python .. code-block:: python
from pathlib import Path from pathlib import Path
from nni.experiment import Experiment
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
tuner = HyperoptTuner('tpe') from nni.experiment import Experiment
search_space = { search_space = {
"dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] }, "dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] },
...@@ -87,16 +76,18 @@ Below is an example for this new launching approach. You can also find this code ...@@ -87,16 +76,18 @@ Below is an example for this new launching approach. You can also find this code
"learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] } "learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] }
} }
experiment = Experiment(tuner, 'local') experiment = Experiment('local')
experiment.config.experiment_name = 'test' experiment.config.experiment_name = 'MNIST example'
experiment.config.trial_concurrency = 2 experiment.config.trial_concurrency = 2
experiment.config.max_trial_number = 5 experiment.config.max_trial_number = 10
experiment.config.search_space = search_space experiment.config.search_space = search_space
experiment.config.trial_command = 'python3 mnist.py' experiment.config.trial_command = 'python3 mnist.py'
experiment.config.trial_code_directory = Path(__file__).parent experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.training_service.use_active_gpu = True experiment.config.training_service.use_active_gpu = True
experiment.run(8081) experiment.run(8080)
Start and Manage a New Experiment Start and Manage a New Experiment
--------------------------------- ---------------------------------
......
...@@ -26,8 +26,7 @@ ...@@ -26,8 +26,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[2021-02-25 07:50:38] Tuner not set, wait for connect...\n", "[2021-03-05 12:18:28] Connect to port 8080 success, experiment id is DH8pVfXc, status is RUNNING.\n"
"[2021-02-25 07:50:38] Connect to port 8080 success, experiment id is IF0JnfLE, status is RUNNING.\n"
] ]
} }
], ],
...@@ -53,27 +52,27 @@ ...@@ -53,27 +52,27 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'id': 'IF0JnfLE',\n", "{'id': 'DH8pVfXc',\n",
" 'revision': 6,\n", " 'revision': 4,\n",
" 'execDuration': 28,\n", " 'execDuration': 10,\n",
" 'logDir': '/home/ningshang/nni-experiments/IF0JnfLE',\n", " 'logDir': '/home/ningshang/nni-experiments/DH8pVfXc',\n",
" 'nextSequenceId': 2,\n", " 'nextSequenceId': 1,\n",
" 'params': {'authorName': 'default',\n", " 'params': {'authorName': 'default',\n",
" 'experimentName': 'example_sklearn-classification',\n", " 'experimentName': 'example_sklearn-classification',\n",
" 'trialConcurrency': 1,\n", " 'trialConcurrency': 1,\n",
" 'maxExecDuration': 3600,\n", " 'maxExecDuration': 3600,\n",
" 'maxTrialNum': 5,\n", " 'maxTrialNum': 100,\n",
" 'searchSpace': '{\"C\": {\"_type\": \"uniform\", \"_value\": [0.1, 1]}, \"kernel\": {\"_type\": \"choice\", \"_value\": [\"linear\", \"rbf\", \"poly\", \"sigmoid\"]}, \"degree\": {\"_type\": \"choice\", \"_value\": [1, 2, 3, 4]}, \"gamma\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}, \"coef0\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}}',\n", " 'searchSpace': '{\"C\": {\"_type\": \"uniform\", \"_value\": [0.1, 1]}, \"kernel\": {\"_type\": \"choice\", \"_value\": [\"linear\", \"rbf\", \"poly\", \"sigmoid\"]}, \"degree\": {\"_type\": \"choice\", \"_value\": [1, 2, 3, 4]}, \"gamma\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}, \"coef0\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}}',\n",
" 'trainingServicePlatform': 'local',\n", " 'trainingServicePlatform': 'local',\n",
" 'tuner': {'builtinTunerName': 'TPE',\n", " 'tuner': {'builtinTunerName': 'TPE',\n",
" 'classArgs': {'optimize_mode': 'maximize'},\n", " 'classArgs': {'optimize_mode': 'maximize'},\n",
" 'checkpointDir': '/home/ningshang/nni-experiments/IF0JnfLE/checkpoint'},\n", " 'checkpointDir': '/home/ningshang/nni-experiments/DH8pVfXc/checkpoint'},\n",
" 'versionCheck': True,\n", " 'versionCheck': True,\n",
" 'clusterMetaData': [{'key': 'trial_config',\n", " 'clusterMetaData': [{'key': 'trial_config',\n",
" 'value': {'command': 'python3 main.py',\n", " 'value': {'command': 'python3 main.py',\n",
" 'codeDir': '/home/ningshang/nni/examples/trials/sklearn/classification/.',\n", " 'codeDir': '/home/ningshang/nni/examples/trials/sklearn/classification/.',\n",
" 'gpuNum': 0}}]},\n", " 'gpuNum': 0}}]},\n",
" 'startTime': 1614239412494}" " 'startTime': 1614946699989}"
] ]
}, },
"execution_count": 2, "execution_count": 2,
...@@ -90,9 +89,17 @@ ...@@ -90,9 +89,17 @@
"execution_count": 3, "execution_count": 3,
"id": "printable-bookmark", "id": "printable-bookmark",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2021-03-05 12:18:32] (root) Successfully update maxTrialNum.\n"
]
}
],
"source": [ "source": [
"experiment.update_max_trial_number(10)" "experiment.update_max_trial_number(200)"
] ]
}, },
{ {
...@@ -104,27 +111,27 @@ ...@@ -104,27 +111,27 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'id': 'IF0JnfLE',\n", "{'id': 'DH8pVfXc',\n",
" 'revision': 8,\n", " 'revision': 5,\n",
" 'execDuration': 32,\n", " 'execDuration': 14,\n",
" 'logDir': '/home/ningshang/nni-experiments/IF0JnfLE',\n", " 'logDir': '/home/ningshang/nni-experiments/DH8pVfXc',\n",
" 'nextSequenceId': 2,\n", " 'nextSequenceId': 1,\n",
" 'params': {'authorName': 'default',\n", " 'params': {'authorName': 'default',\n",
" 'experimentName': 'example_sklearn-classification',\n", " 'experimentName': 'example_sklearn-classification',\n",
" 'trialConcurrency': 1,\n", " 'trialConcurrency': 1,\n",
" 'maxExecDuration': 3600,\n", " 'maxExecDuration': 3600,\n",
" 'maxTrialNum': 10,\n", " 'maxTrialNum': 200,\n",
" 'searchSpace': '{\"C\": {\"_type\": \"uniform\", \"_value\": [0.1, 1]}, \"kernel\": {\"_type\": \"choice\", \"_value\": [\"linear\", \"rbf\", \"poly\", \"sigmoid\"]}, \"degree\": {\"_type\": \"choice\", \"_value\": [1, 2, 3, 4]}, \"gamma\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}, \"coef0\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}}',\n", " 'searchSpace': '{\"C\": {\"_type\": \"uniform\", \"_value\": [0.1, 1]}, \"kernel\": {\"_type\": \"choice\", \"_value\": [\"linear\", \"rbf\", \"poly\", \"sigmoid\"]}, \"degree\": {\"_type\": \"choice\", \"_value\": [1, 2, 3, 4]}, \"gamma\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}, \"coef0\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}}',\n",
" 'trainingServicePlatform': 'local',\n", " 'trainingServicePlatform': 'local',\n",
" 'tuner': {'builtinTunerName': 'TPE',\n", " 'tuner': {'builtinTunerName': 'TPE',\n",
" 'classArgs': {'optimize_mode': 'maximize'},\n", " 'classArgs': {'optimize_mode': 'maximize'},\n",
" 'checkpointDir': '/home/ningshang/nni-experiments/IF0JnfLE/checkpoint'},\n", " 'checkpointDir': '/home/ningshang/nni-experiments/DH8pVfXc/checkpoint'},\n",
" 'versionCheck': True,\n", " 'versionCheck': True,\n",
" 'clusterMetaData': [{'key': 'trial_config',\n", " 'clusterMetaData': [{'key': 'trial_config',\n",
" 'value': {'command': 'python3 main.py',\n", " 'value': {'command': 'python3 main.py',\n",
" 'codeDir': '/home/ningshang/nni/examples/trials/sklearn/classification/.',\n", " 'codeDir': '/home/ningshang/nni/examples/trials/sklearn/classification/.',\n",
" 'gpuNum': 0}}]},\n", " 'gpuNum': 0}}]},\n",
" 'startTime': 1614239412494}" " 'startTime': 1614946699989}"
] ]
}, },
"execution_count": 4, "execution_count": 4,
...@@ -154,8 +161,8 @@ ...@@ -154,8 +161,8 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[2021-02-25 07:50:49] Stopping experiment, please wait...\n", "[2021-03-05 12:18:36] Stopping experiment, please wait...\n",
"[2021-02-25 07:50:49] Experiment stopped\n" "[2021-03-05 12:18:38] Experiment stopped\n"
] ]
} }
], ],
......
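The profile dictionary printed in the notebook output above stores ``searchSpace`` as a JSON string and ``startTime`` as a large integer. The sketch below decodes both from a trimmed copy of that output. Treating ``startTime`` as milliseconds since the Unix epoch is an assumption; it is consistent with the log timestamps shown, but the diff does not state it.

```python
import json
from datetime import datetime, timezone

# Trimmed copy of the profile shown in the notebook output (only a few keys kept).
profile = {
    "id": "DH8pVfXc",
    "params": {
        "maxTrialNum": 200,
        "searchSpace": (
            '{"C": {"_type": "uniform", "_value": [0.1, 1]}, '
            '"kernel": {"_type": "choice", '
            '"_value": ["linear", "rbf", "poly", "sigmoid"]}}'
        ),
    },
    "startTime": 1614946699989,
}

# The search space is a JSON string, so it has to be parsed before use.
search_space = json.loads(profile["params"]["searchSpace"])

# Assumption: startTime is milliseconds since the Unix epoch.
started = datetime.fromtimestamp(profile["startTime"] / 1000, tz=timezone.utc)

print(sorted(search_space))
print(started.isoformat())
```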
...@@ -8,36 +8,17 @@ ...@@ -8,36 +8,17 @@
"## Start and Manage a New Experiment" "## Start and Manage a New Experiment"
] ]
}, },
{
"cell_type": "markdown",
"id": "immediate-daily",
"metadata": {},
"source": [
"### 1. Initialize Tuner"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "formed-grounds",
"metadata": {},
"outputs": [],
"source": [
"from nni.algorithms.hpo.gridsearch_tuner import GridSearchTuner\n",
"tuner = GridSearchTuner()"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "reported-somerset", "id": "reported-somerset",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2. Configure Search Space" "### 1. Configure Search Space"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "potential-williams", "id": "potential-williams",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -56,24 +37,27 @@ ...@@ -56,24 +37,27 @@
"id": "greek-archive", "id": "greek-archive",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 3. Configure Experiment " "### 2. Configure Experiment "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"id": "fiscal-expansion", "id": "fiscal-expansion",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from nni.experiment import Experiment\n", "from nni.experiment import Experiment\n",
"experiment = Experiment(tuner, 'local')\n", "experiment = Experiment('local')\n",
"experiment.config.experiment_name = 'test'\n", "experiment.config.experiment_name = 'Example'\n",
"experiment.config.trial_concurrency = 2\n", "experiment.config.trial_concurrency = 2\n",
"experiment.config.max_trial_number = 5\n", "experiment.config.max_trial_number = 10\n",
"experiment.config.search_space = search_space\n", "experiment.config.search_space = search_space\n",
"experiment.config.trial_command = 'python3 main.py'\n", "experiment.config.trial_command = 'python3 main.py'\n",
"experiment.config.trial_code_directory = './'" "experiment.config.trial_code_directory = './'\n",
"experiment.config.tuner.name = 'TPE'\n",
"experiment.config.tuner.class_args['optimize_mode'] = 'maximize'\n",
"experiment.config.training_service.use_active_gpu = True"
] ]
}, },
{ {
...@@ -81,12 +65,12 @@ ...@@ -81,12 +65,12 @@
"id": "received-tattoo", "id": "received-tattoo",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 4. Start Experiment" "### 3. Start Experiment"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"id": "pleasant-patent", "id": "pleasant-patent",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -94,17 +78,15 @@ ...@@ -94,17 +78,15 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[2021-02-22 12:27:11] Creating experiment, Experiment ID: bj025qo4\n", "[2021-03-05 12:12:19] Creating experiment, Experiment ID: wdt0le3v\n",
"[2021-02-22 12:27:11] Connecting IPC pipe...\n", "[2021-03-05 12:12:19] Statring web server...\n",
"[2021-02-22 12:27:15] Statring web server...\n", "[2021-03-05 12:12:20] Setting up...\n",
"[2021-02-22 12:27:16] Setting up...\n", "[2021-03-05 12:12:20] Web UI URLs: http://127.0.0.1:8080 http://10.0.1.5:8080 http://172.17.0.1:8080\n"
"[2021-02-22 12:27:16] Dispatcher started\n",
"[2021-02-22 12:27:16] Web UI URLs: http://127.0.0.1:8081 http://10.0.1.5:8081 http://172.17.0.1:8081\n"
] ]
} }
], ],
"source": [ "source": [
"experiment.start(8081)" "experiment.start(8080)"
] ]
}, },
{ {
...@@ -112,12 +94,12 @@ ...@@ -112,12 +94,12 @@
"id": "miniature-prison", "id": "miniature-prison",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 5. Experiment View & Control" "### 4. Experiment View & Control"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 4,
"id": "animated-english", "id": "animated-english",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -127,7 +109,7 @@ ...@@ -127,7 +109,7 @@
"'RUNNING'" "'RUNNING'"
] ]
}, },
"execution_count": 5, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -138,18 +120,18 @@ ...@@ -138,18 +120,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 5,
"id": "alpha-ottawa", "id": "alpha-ottawa",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"[TrialResult(parameter={'coef0': 0.01, 'gamma': 0.01, 'degree': 1, 'kernel': 'linear', 'C': 0.1}, value=0.9866666666666667, trialJobId='B55mT'),\n", "[TrialResult(parameter={'C': 0.30000000000000004, 'kernel': 'linear', 'degree': 3, 'gamma': 0.03, 'coef0': 0.07}, value=0.9888888888888889, trialJobId='VLqU9'),\n",
" TrialResult(parameter={'coef0': 0.02, 'gamma': 0.01, 'degree': 1, 'kernel': 'linear', 'C': 0.1}, value=0.9866666666666667, trialJobId='QkhD0')]" " TrialResult(parameter={'C': 0.5, 'kernel': 'sigmoid', 'degree': 1, 'gamma': 0.03, 'coef0': 0.07}, value=0.8888888888888888, trialJobId='DLo6r')]"
] ]
}, },
"execution_count": 6, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -160,18 +142,18 @@ ...@@ -160,18 +142,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 6,
"id": "unique-rendering", "id": "unique-rendering",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'B55mT': [TrialMetricData(timestamp=1613996853005, trialJobId='B55mT', parameterId='0', type='FINAL', sequence=0, data=0.9866666666666667)],\n", "{'DLo6r': [TrialMetricData(timestamp=1614946351592, trialJobId='DLo6r', parameterId='1', type='FINAL', sequence=0, data=0.8888888888888888)],\n",
" 'QkhD0': [TrialMetricData(timestamp=1613996853843, trialJobId='QkhD0', parameterId='1', type='FINAL', sequence=0, data=0.9866666666666667)]}" " 'VLqU9': [TrialMetricData(timestamp=1614946351607, trialJobId='VLqU9', parameterId='0', type='FINAL', sequence=0, data=0.9888888888888889)]}"
] ]
}, },
"execution_count": 7, "execution_count": 6,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -185,12 +167,12 @@ ...@@ -185,12 +167,12 @@
"id": "welsh-difference", "id": "welsh-difference",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 6. Stop Experiment" "### 5. Stop Experiment"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 7,
"id": "technological-cleanup", "id": "technological-cleanup",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -198,10 +180,8 @@ ...@@ -198,10 +180,8 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[2021-02-22 12:28:16] Stopping experiment, please wait...\n", "[2021-03-05 12:12:40] Stopping experiment, please wait...\n",
"[2021-02-22 12:28:16] Dispatcher exiting...\n", "[2021-03-05 12:12:42] Experiment stopped\n"
"[2021-02-22 12:28:17] Experiment stopped\n",
"[2021-02-22 12:28:19] Dispatcher terminiated\n"
] ]
} }
], ],
......
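The trial results printed in the notebook above can be post-processed with ordinary Python. As a minimal sketch, the snippet below picks the best trial from the two results shown, using ``max`` because the tuner is configured with ``optimize_mode`` set to ``'maximize'``. The ``TrialResult`` class here is a hypothetical stand-in that only mirrors the printed fields; it is not NNI's own class.

```python
from typing import Any, Dict, NamedTuple


class TrialResult(NamedTuple):
    """Hypothetical stand-in mirroring the fields in the printed output."""
    parameter: Dict[str, Any]
    value: float
    trialJobId: str


# The two results copied from the notebook output.
results = [
    TrialResult(
        parameter={"C": 0.30000000000000004, "kernel": "linear",
                   "degree": 3, "gamma": 0.03, "coef0": 0.07},
        value=0.9888888888888889,
        trialJobId="VLqU9",
    ),
    TrialResult(
        parameter={"C": 0.5, "kernel": "sigmoid",
                   "degree": 1, "gamma": 0.03, "coef0": 0.07},
        value=0.8888888888888888,
        trialJobId="DLo6r",
    ),
]

# With optimize_mode 'maximize', the best trial has the highest final value.
best = max(results, key=lambda r: r.value)
print(best.trialJobId, best.parameter["kernel"], best.value)
```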
# FIXME: For demonstration only. It should not be here # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path """
Example showing how to create experiment with Python code.
"""
from nni.experiment import Experiment from pathlib import Path
from nni.experiment import RemoteMachineConfig
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
tuner = HyperoptTuner('tpe') from nni.experiment import Experiment, RemoteMachineConfig
search_space = { search_space = {
"dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] }, "dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] },
...@@ -16,13 +17,15 @@ search_space = { ...@@ -16,13 +17,15 @@ search_space = {
"learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] } "learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] }
} }
experiment = Experiment(tuner, ['local', 'remote']) experiment = Experiment(['local', 'remote'])
experiment.config.experiment_name = 'test' experiment.config.experiment_name = 'test'
experiment.config.trial_concurrency = 3 experiment.config.trial_concurrency = 3
experiment.config.max_trial_number = 10 experiment.config.max_trial_number = 10
experiment.config.search_space = search_space experiment.config.search_space = search_space
experiment.config.trial_command = 'python3 mnist.py' experiment.config.trial_command = 'python3 mnist.py'
experiment.config.trial_code_directory = Path(__file__).parent experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.training_service[0].use_active_gpu = True experiment.config.training_service[0].use_active_gpu = True
experiment.config.training_service[1].reuse_mode = True experiment.config.training_service[1].reuse_mode = True
rm_conf = RemoteMachineConfig() rm_conf = RemoteMachineConfig()
......
# FIXME: For demonstration only. It should not be here # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Example showing how to create experiment with Python code.
"""
from pathlib import Path from pathlib import Path
from nni.experiment import Experiment from nni.experiment import Experiment
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
tuner = HyperoptTuner('tpe')
search_space = { search_space = {
"dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] }, "dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] },
...@@ -15,13 +17,15 @@ search_space = { ...@@ -15,13 +17,15 @@ search_space = {
"learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] } "learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] }
} }
experiment = Experiment(tuner, 'local') experiment = Experiment('local')
experiment.config.experiment_name = 'test' experiment.config.experiment_name = 'MNIST example'
experiment.config.trial_concurrency = 2 experiment.config.trial_concurrency = 2
experiment.config.max_trial_number = 5 experiment.config.max_trial_number = 10
experiment.config.search_space = search_space experiment.config.search_space = search_space
experiment.config.trial_command = 'python3 mnist.py' experiment.config.trial_command = 'python3 mnist.py'
experiment.config.trial_code_directory = Path(__file__).parent experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.training_service.use_active_gpu = True experiment.config.training_service.use_active_gpu = True
experiment.run(8081) experiment.run(8080)
...@@ -26,8 +26,7 @@ ...@@ -26,8 +26,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[2021-02-25 07:50:38] Tuner not set, wait for connect...\n", "[2021-03-05 12:18:28] Connect to port 8080 success, experiment id is DH8pVfXc, status is RUNNING.\n"
"[2021-02-25 07:50:38] Connect to port 8080 success, experiment id is IF0JnfLE, status is RUNNING.\n"
] ]
} }
], ],
...@@ -53,27 +52,27 @@ ...@@ -53,27 +52,27 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'id': 'IF0JnfLE',\n", "{'id': 'DH8pVfXc',\n",
" 'revision': 6,\n", " 'revision': 4,\n",
" 'execDuration': 28,\n", " 'execDuration': 10,\n",
" 'logDir': '/home/ningshang/nni-experiments/IF0JnfLE',\n", " 'logDir': '/home/ningshang/nni-experiments/DH8pVfXc',\n",
" 'nextSequenceId': 2,\n", " 'nextSequenceId': 1,\n",
" 'params': {'authorName': 'default',\n", " 'params': {'authorName': 'default',\n",
" 'experimentName': 'example_sklearn-classification',\n", " 'experimentName': 'example_sklearn-classification',\n",
" 'trialConcurrency': 1,\n", " 'trialConcurrency': 1,\n",
" 'maxExecDuration': 3600,\n", " 'maxExecDuration': 3600,\n",
" 'maxTrialNum': 5,\n", " 'maxTrialNum': 100,\n",
" 'searchSpace': '{\"C\": {\"_type\": \"uniform\", \"_value\": [0.1, 1]}, \"kernel\": {\"_type\": \"choice\", \"_value\": [\"linear\", \"rbf\", \"poly\", \"sigmoid\"]}, \"degree\": {\"_type\": \"choice\", \"_value\": [1, 2, 3, 4]}, \"gamma\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}, \"coef0\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}}',\n", " 'searchSpace': '{\"C\": {\"_type\": \"uniform\", \"_value\": [0.1, 1]}, \"kernel\": {\"_type\": \"choice\", \"_value\": [\"linear\", \"rbf\", \"poly\", \"sigmoid\"]}, \"degree\": {\"_type\": \"choice\", \"_value\": [1, 2, 3, 4]}, \"gamma\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}, \"coef0\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}}',\n",
" 'trainingServicePlatform': 'local',\n", " 'trainingServicePlatform': 'local',\n",
" 'tuner': {'builtinTunerName': 'TPE',\n", " 'tuner': {'builtinTunerName': 'TPE',\n",
" 'classArgs': {'optimize_mode': 'maximize'},\n", " 'classArgs': {'optimize_mode': 'maximize'},\n",
" 'checkpointDir': '/home/ningshang/nni-experiments/IF0JnfLE/checkpoint'},\n", " 'checkpointDir': '/home/ningshang/nni-experiments/DH8pVfXc/checkpoint'},\n",
" 'versionCheck': True,\n", " 'versionCheck': True,\n",
" 'clusterMetaData': [{'key': 'trial_config',\n", " 'clusterMetaData': [{'key': 'trial_config',\n",
" 'value': {'command': 'python3 main.py',\n", " 'value': {'command': 'python3 main.py',\n",
" 'codeDir': '/home/ningshang/nni/examples/trials/sklearn/classification/.',\n", " 'codeDir': '/home/ningshang/nni/examples/trials/sklearn/classification/.',\n",
" 'gpuNum': 0}}]},\n", " 'gpuNum': 0}}]},\n",
" 'startTime': 1614239412494}" " 'startTime': 1614946699989}"
] ]
}, },
"execution_count": 2, "execution_count": 2,
...@@ -90,9 +89,17 @@ ...@@ -90,9 +89,17 @@
"execution_count": 3, "execution_count": 3,
"id": "printable-bookmark", "id": "printable-bookmark",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2021-03-05 12:18:32] (root) Successfully update maxTrialNum.\n"
]
}
],
"source": [ "source": [
"experiment.update_max_trial_number(10)" "experiment.update_max_trial_number(200)"
] ]
}, },
{ {
...@@ -104,27 +111,27 @@ ...@@ -104,27 +111,27 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'id': 'IF0JnfLE',\n", "{'id': 'DH8pVfXc',\n",
" 'revision': 8,\n", " 'revision': 5,\n",
" 'execDuration': 32,\n", " 'execDuration': 14,\n",
" 'logDir': '/home/ningshang/nni-experiments/IF0JnfLE',\n", " 'logDir': '/home/ningshang/nni-experiments/DH8pVfXc',\n",
" 'nextSequenceId': 2,\n", " 'nextSequenceId': 1,\n",
" 'params': {'authorName': 'default',\n", " 'params': {'authorName': 'default',\n",
" 'experimentName': 'example_sklearn-classification',\n", " 'experimentName': 'example_sklearn-classification',\n",
" 'trialConcurrency': 1,\n", " 'trialConcurrency': 1,\n",
" 'maxExecDuration': 3600,\n", " 'maxExecDuration': 3600,\n",
" 'maxTrialNum': 10,\n", " 'maxTrialNum': 200,\n",
" 'searchSpace': '{\"C\": {\"_type\": \"uniform\", \"_value\": [0.1, 1]}, \"kernel\": {\"_type\": \"choice\", \"_value\": [\"linear\", \"rbf\", \"poly\", \"sigmoid\"]}, \"degree\": {\"_type\": \"choice\", \"_value\": [1, 2, 3, 4]}, \"gamma\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}, \"coef0\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}}',\n", " 'searchSpace': '{\"C\": {\"_type\": \"uniform\", \"_value\": [0.1, 1]}, \"kernel\": {\"_type\": \"choice\", \"_value\": [\"linear\", \"rbf\", \"poly\", \"sigmoid\"]}, \"degree\": {\"_type\": \"choice\", \"_value\": [1, 2, 3, 4]}, \"gamma\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}, \"coef0\": {\"_type\": \"uniform\", \"_value\": [0.01, 0.1]}}',\n",
" 'trainingServicePlatform': 'local',\n", " 'trainingServicePlatform': 'local',\n",
" 'tuner': {'builtinTunerName': 'TPE',\n", " 'tuner': {'builtinTunerName': 'TPE',\n",
" 'classArgs': {'optimize_mode': 'maximize'},\n", " 'classArgs': {'optimize_mode': 'maximize'},\n",
" 'checkpointDir': '/home/ningshang/nni-experiments/IF0JnfLE/checkpoint'},\n", " 'checkpointDir': '/home/ningshang/nni-experiments/DH8pVfXc/checkpoint'},\n",
" 'versionCheck': True,\n", " 'versionCheck': True,\n",
" 'clusterMetaData': [{'key': 'trial_config',\n", " 'clusterMetaData': [{'key': 'trial_config',\n",
" 'value': {'command': 'python3 main.py',\n", " 'value': {'command': 'python3 main.py',\n",
" 'codeDir': '/home/ningshang/nni/examples/trials/sklearn/classification/.',\n", " 'codeDir': '/home/ningshang/nni/examples/trials/sklearn/classification/.',\n",
" 'gpuNum': 0}}]},\n", " 'gpuNum': 0}}]},\n",
" 'startTime': 1614239412494}" " 'startTime': 1614946699989}"
] ]
}, },
"execution_count": 4, "execution_count": 4,
...@@ -154,8 +161,8 @@ ...@@ -154,8 +161,8 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[2021-02-25 07:50:49] Stopping experiment, please wait...\n", "[2021-03-05 12:18:36] Stopping experiment, please wait...\n",
"[2021-02-25 07:50:49] Experiment stopped\n" "[2021-03-05 12:18:38] Experiment stopped\n"
] ]
} }
], ],
......
...@@ -8,36 +8,17 @@ ...@@ -8,36 +8,17 @@
"## Start and Manage a New Experiment" "## Start and Manage a New Experiment"
] ]
}, },
{
"cell_type": "markdown",
"id": "immediate-daily",
"metadata": {},
"source": [
"### 1. Initialize Tuner"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "formed-grounds",
"metadata": {},
"outputs": [],
"source": [
"from nni.algorithms.hpo.gridsearch_tuner import GridSearchTuner\n",
"tuner = GridSearchTuner()"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "reported-somerset", "id": "reported-somerset",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 2. Configure Search Space" "### 1. Configure Search Space"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "potential-williams", "id": "potential-williams",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -56,24 +37,27 @@ ...@@ -56,24 +37,27 @@
"id": "greek-archive", "id": "greek-archive",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 3. Configure Experiment " "### 2. Configure Experiment "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"id": "fiscal-expansion", "id": "fiscal-expansion",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from nni.experiment import Experiment\n", "from nni.experiment import Experiment\n",
"experiment = Experiment(tuner, 'local')\n", "experiment = Experiment('local')\n",
"experiment.config.experiment_name = 'test'\n", "experiment.config.experiment_name = 'Example'\n",
"experiment.config.trial_concurrency = 2\n", "experiment.config.trial_concurrency = 2\n",
"experiment.config.max_trial_number = 5\n", "experiment.config.max_trial_number = 10\n",
"experiment.config.search_space = search_space\n", "experiment.config.search_space = search_space\n",
"experiment.config.trial_command = 'python3 main.py'\n", "experiment.config.trial_command = 'python3 main.py'\n",
"experiment.config.trial_code_directory = './'" "experiment.config.trial_code_directory = './'\n",
"experiment.config.tuner.name = 'TPE'\n",
"experiment.config.tuner.class_args['optimize_mode'] = 'maximize'\n",
"experiment.config.training_service.use_active_gpu = True"
] ]
}, },
{ {
...@@ -81,12 +65,12 @@ ...@@ -81,12 +65,12 @@
"id": "received-tattoo", "id": "received-tattoo",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 4. Start Experiment" "### 3. Start Experiment"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"id": "pleasant-patent", "id": "pleasant-patent",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -94,17 +78,15 @@ ...@@ -94,17 +78,15 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[2021-02-22 12:27:11] Creating experiment, Experiment ID: bj025qo4\n", "[2021-03-05 12:12:19] Creating experiment, Experiment ID: wdt0le3v\n",
"[2021-02-22 12:27:11] Connecting IPC pipe...\n", "[2021-03-05 12:12:19] Statring web server...\n",
"[2021-02-22 12:27:15] Statring web server...\n", "[2021-03-05 12:12:20] Setting up...\n",
"[2021-02-22 12:27:16] Setting up...\n", "[2021-03-05 12:12:20] Web UI URLs: http://127.0.0.1:8080 http://10.0.1.5:8080 http://172.17.0.1:8080\n"
"[2021-02-22 12:27:16] Dispatcher started\n",
"[2021-02-22 12:27:16] Web UI URLs: http://127.0.0.1:8081 http://10.0.1.5:8081 http://172.17.0.1:8081\n"
] ]
} }
], ],
"source": [ "source": [
"experiment.start(8081)" "experiment.start(8080)"
] ]
}, },
{ {
...@@ -112,12 +94,12 @@ ...@@ -112,12 +94,12 @@
"id": "miniature-prison", "id": "miniature-prison",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 5. Experiment View & Control" "### 4. Experiment View & Control"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 4,
"id": "animated-english", "id": "animated-english",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -127,7 +109,7 @@ ...@@ -127,7 +109,7 @@
"'RUNNING'" "'RUNNING'"
] ]
}, },
"execution_count": 5, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -138,18 +120,18 @@ ...@@ -138,18 +120,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 5,
"id": "alpha-ottawa", "id": "alpha-ottawa",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"[TrialResult(parameter={'coef0': 0.01, 'gamma': 0.01, 'degree': 1, 'kernel': 'linear', 'C': 0.1}, value=0.9866666666666667, trialJobId='B55mT'),\n", "[TrialResult(parameter={'C': 0.30000000000000004, 'kernel': 'linear', 'degree': 3, 'gamma': 0.03, 'coef0': 0.07}, value=0.9888888888888889, trialJobId='VLqU9'),\n",
" TrialResult(parameter={'coef0': 0.02, 'gamma': 0.01, 'degree': 1, 'kernel': 'linear', 'C': 0.1}, value=0.9866666666666667, trialJobId='QkhD0')]" " TrialResult(parameter={'C': 0.5, 'kernel': 'sigmoid', 'degree': 1, 'gamma': 0.03, 'coef0': 0.07}, value=0.8888888888888888, trialJobId='DLo6r')]"
] ]
}, },
"execution_count": 6, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -160,18 +142,18 @@ ...@@ -160,18 +142,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 6,
"id": "unique-rendering", "id": "unique-rendering",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'B55mT': [TrialMetricData(timestamp=1613996853005, trialJobId='B55mT', parameterId='0', type='FINAL', sequence=0, data=0.9866666666666667)],\n", "{'DLo6r': [TrialMetricData(timestamp=1614946351592, trialJobId='DLo6r', parameterId='1', type='FINAL', sequence=0, data=0.8888888888888888)],\n",
" 'QkhD0': [TrialMetricData(timestamp=1613996853843, trialJobId='QkhD0', parameterId='1', type='FINAL', sequence=0, data=0.9866666666666667)]}" " 'VLqU9': [TrialMetricData(timestamp=1614946351607, trialJobId='VLqU9', parameterId='0', type='FINAL', sequence=0, data=0.9888888888888889)]}"
] ]
}, },
"execution_count": 7, "execution_count": 6,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -185,12 +167,12 @@ ...@@ -185,12 +167,12 @@
"id": "welsh-difference", "id": "welsh-difference",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### 6. Stop Experiment" "### 5. Stop Experiment"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 7,
"id": "technological-cleanup", "id": "technological-cleanup",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -198,10 +180,8 @@ ...@@ -198,10 +180,8 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[2021-02-22 12:28:16] Stopping experiment, please wait...\n", "[2021-03-05 12:12:40] Stopping experiment, please wait...\n",
"[2021-02-22 12:28:16] Dispatcher exiting...\n", "[2021-03-05 12:12:42] Experiment stopped\n"
"[2021-02-22 12:28:17] Experiment stopped\n",
"[2021-02-22 12:28:19] Dispatcher terminiated\n"
] ]
} }
], ],
......
...@@ -83,9 +83,9 @@ class ExperimentConfig(ConfigBase): ...@@ -83,9 +83,9 @@ class ExperimentConfig(ConfigBase):
def validate(self, initialized_tuner: bool = False) -> None: def validate(self, initialized_tuner: bool = False) -> None:
super().validate() super().validate()
if initialized_tuner: if initialized_tuner:
_validate_for_exp(self) _validate_for_exp(self.canonical())
else: else:
_validate_for_nnictl(self) _validate_for_nnictl(self.canonical())
if self.trial_gpu_number and hasattr(self.training_service, 'use_active_gpu'): if self.trial_gpu_number and hasattr(self.training_service, 'use_active_gpu'):
if self.training_service.use_active_gpu is None: if self.training_service.use_active_gpu is None:
raise ValueError('Please set "use_active_gpu"') raise ValueError('Please set "use_active_gpu"')
...@@ -106,7 +106,10 @@ _canonical_rules = { ...@@ -106,7 +106,10 @@ _canonical_rules = {
'trial_code_directory': util.canonical_path, 'trial_code_directory': util.canonical_path,
'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None, 'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
'experiment_working_directory': util.canonical_path, 'experiment_working_directory': util.canonical_path,
'tuner_gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value 'tuner_gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value,
'tuner': lambda config: None if config.name == '_none_' else config,
'assessor': lambda config: None if config.name == '_none_' else config,
'advisor': lambda config: None if config.name == '_none_' else config,
} }
_validation_rules = { _validation_rules = {
......
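The three rules added to `_canonical_rules` above turn a tuner, assessor, or advisor named `_none_` into `None` when the config is canonicalized. Below is a minimal, self-contained sketch of that behavior. `AlgoStub`, `RULES`, and `canonicalize` are hypothetical stand-ins for illustration, not NNI's `AlgorithmConfig` or `canonical()` implementation.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional


@dataclass
class AlgoStub:
    """Hypothetical stand-in for an algorithm config (not the NNI class)."""
    name: str
    class_args: Dict[str, Any] = field(default_factory=dict)


# Same shape as the lambdas added in the diff: the '_none_' placeholder maps to None.
RULES: Dict[str, Callable[[AlgoStub], Optional[AlgoStub]]] = {
    key: (lambda config: None if config.name == '_none_' else config)
    for key in ('tuner', 'assessor', 'advisor')
}


def canonicalize(fields: Dict[str, AlgoStub]) -> Dict[str, Optional[AlgoStub]]:
    """Apply the matching rule to each field; fields without a rule pass through."""
    return {key: RULES[key](value) if key in RULES else value
            for key, value in fields.items()}


fields = {
    'tuner': AlgoStub('TPE', {'optimize_mode': 'maximize'}),
    'assessor': AlgoStub('_none_'),
    'advisor': AlgoStub('_none_'),
}
result = canonicalize(fields)
print(result)
```

In this sketch only the configured tuner survives; the untouched placeholders become `None`, which matches what `validate()` sees once it is given `self.canonical()`.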
...@@ -14,7 +14,7 @@ _logger = logging.getLogger(__name__) ...@@ -14,7 +14,7 @@ _logger = logging.getLogger(__name__)
def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str, Any]: def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str, Any]:
config.validate(skip_nnictl) config.validate(False)
data = config.json() data = config.json()
ts = data.pop('trainingService') ts = data.pop('trainingService')
......
...@@ -3,7 +3,6 @@ import logging ...@@ -3,7 +3,6 @@ import logging
from pathlib import Path from pathlib import Path
import socket import socket
from subprocess import Popen from subprocess import Popen
from threading import Thread
import time import time
from typing import Optional, Union, List, overload, Any from typing import Optional, Union, List, overload, Any
...@@ -11,14 +10,11 @@ import colorama ...@@ -11,14 +10,11 @@ import colorama
import psutil import psutil
import nni.runtime.log import nni.runtime.log
from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.tuner import Tuner
from .config import ExperimentConfig from .config import ExperimentConfig, AlgorithmConfig
from .data import TrialJob, TrialMetricData, TrialResult from .data import TrialJob, TrialMetricData, TrialResult
from . import launcher from . import launcher
from . import management from . import management
from .pipe import Pipe
from . import rest from . import rest
from ..tools.nnictl.command_utils import kill_command from ..tools.nnictl.command_utils import kill_command
...@@ -39,7 +35,7 @@ class Experiment: ...@@ -39,7 +35,7 @@ class Experiment:
""" """
@overload @overload
def __init__(self, tuner: Tuner, config: ExperimentConfig) -> None: def __init__(self, config: ExperimentConfig) -> None:
""" """
Prepare an experiment. Prepare an experiment.
...@@ -47,21 +43,19 @@ class Experiment: ...@@ -47,21 +43,19 @@ class Experiment:
Parameters Parameters
---------- ----------
tuner
A tuner instance.
config config
Experiment configuration. Experiment configuration.
""" """
... ...
@overload @overload
def __init__(self, tuner: Tuner, training_service: Union[str, List[str]]) -> None: def __init__(self, training_service: Union[str, List[str]]) -> None:
""" """
Prepare an experiment, leaving configuration fields to be set later. Prepare an experiment, leaving configuration fields to be set later.
Example usage:: Example usage::
experiment = Experiment(my_tuner, 'remote') experiment = Experiment('remote')
experiment.config.trial_command = 'python3 trial.py' experiment.config.trial_command = 'python3 trial.py'
experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...)) experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
... ...
...@@ -69,35 +63,26 @@ class Experiment: ...@@ -69,35 +63,26 @@ class Experiment:
Parameters Parameters
---------- ----------
tuner
A tuner instance.
training_service training_service
Name of training service. Name of training service.
Supported value: "local", "remote", "openpai", "aml", "kubeflow", "frameworkcontroller", "adl" and hybrid training service. Supported value: "local", "remote", "openpai", "aml", "kubeflow", "frameworkcontroller", "adl" and hybrid training service.
""" """
... ...
def __init__(self, tuner=None, config=None, training_service=None): def __init__(self, config=None, training_service=None):
self.config: Optional[ExperimentConfig] = None self.config: Optional[ExperimentConfig] = None
self.id: Optional[str] = None self.id: Optional[str] = None
self.port: Optional[int] = None self.port: Optional[int] = None
self.tuner: Optional[Tuner] = None
self._proc: Optional[Popen] = None self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
self._dispatcher: Optional[MsgDispatcher] = None
self._dispatcher_thread: Optional[Thread] = None
-        if isinstance(tuner, Tuner):
-            self.tuner = tuner
-            if isinstance(config, (str, list)):
-                config, training_service = None, config
-
-            if config is None:
-                self.config = ExperimentConfig(training_service)
-            else:
-                self.config = config
+        args = [config, training_service]  # deal with overloading
+        if isinstance(args[0], (str, list)):
+            self.config = ExperimentConfig(args[0])
+            self.config.tuner = AlgorithmConfig(name='_none_', class_args={})
+            self.config.assessor = AlgorithmConfig(name='_none_', class_args={})
+            self.config.advisor = AlgorithmConfig(name='_none_', class_args={})
         else:
-            _logger.warning('Tuner not set, wait for connect...')
+            self.config = args[0]
def start(self, port: int = 8080, debug: bool = False) -> None: def start(self, port: int = 8080, debug: bool = False) -> None:
""" """
...@@ -123,18 +108,11 @@ class Experiment: ...@@ -123,18 +108,11 @@ class Experiment:
log_dir = Path.home() / f'nni-experiments/{self.id}/log' log_dir = Path.home() / f'nni-experiments/{self.id}/log'
nni.runtime.log.start_experiment_log(self.id, log_dir, debug) nni.runtime.log.start_experiment_log(self.id, log_dir, debug)
self._proc, self._pipe = launcher.start_experiment(self.id, self.config, port, debug) self._proc = launcher.start_experiment(self.id, self.config, port, debug)
assert self._proc is not None assert self._proc is not None
assert self._pipe is not None
self.port = port # port will be None if start up failed self.port = port # port will be None if start up failed
# dispatcher must be launched after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
self._dispatcher = self._create_dispatcher()
self._dispatcher_thread = Thread(target=self._dispatcher.run)
self._dispatcher_thread.start()
ips = [self.config.nni_manager_ip] ips = [self.config.nni_manager_ip]
for interfaces in psutil.net_if_addrs().values(): for interfaces in psutil.net_if_addrs().values():
for interface in interfaces: for interface in interfaces:
...@@ -144,9 +122,6 @@ class Experiment: ...@@ -144,9 +122,6 @@ class Experiment:
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
_logger.info(msg) _logger.info(msg)
def _create_dispatcher(self): # overrided by retiarii, temporary solution
return MsgDispatcher(self.tuner, None)
def stop(self) -> None: def stop(self) -> None:
""" """
Stop background experiment. Stop background experiment.
...@@ -157,19 +132,16 @@ class Experiment: ...@@ -157,19 +132,16 @@ class Experiment:
if self.id is not None: if self.id is not None:
nni.runtime.log.stop_experiment_log(self.id) nni.runtime.log.stop_experiment_log(self.id)
if self._proc is not None: if self._proc is not None:
-            kill_command(self._proc.pid)
-        if self._pipe is not None:
-            self._pipe.close()
-        if self._dispatcher_thread is not None:
-            self._dispatcher.stopping = True
-            self._dispatcher_thread.join(timeout=1)
+            try:
+                rest.delete(self.port, '/experiment')
+            except Exception as e:
+                _logger.exception(e)
+                _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
+                kill_command(self._proc.pid)
self.id = None self.id = None
self.port = None self.port = None
self._proc = None self._proc = None
self._pipe = None
self._dispatcher = None
self._dispatcher_thread = None
_logger.info('Experiment stopped') _logger.info('Experiment stopped')
def run(self, port: int = 8080, debug: bool = False) -> bool: def run(self, port: int = 8080, debug: bool = False) -> bool:
...@@ -216,16 +188,6 @@ class Experiment: ...@@ -216,16 +188,6 @@ class Experiment:
_logger.info('Connect to port %d success, experiment id is %s, status is %s.', port, experiment.id, status) _logger.info('Connect to port %d success, experiment id is %s, status is %s.', port, experiment.id, status)
return experiment return experiment
def _experiment_rest_get(self, port: int, api: str) -> Any:
if self.port is None:
raise RuntimeError('Experiment is not running')
return rest.get(self.port, api)
def _experiment_rest_put(self, port: int, api: str, data: Any):
if self.port is None:
raise RuntimeError('Experiment is not running')
rest.put(self.port, api, data)
def get_status(self) -> str: def get_status(self) -> str:
""" """
Return experiment status as a str. Return experiment status as a str.
...@@ -235,7 +197,7 @@ class Experiment: ...@@ -235,7 +197,7 @@ class Experiment:
str str
Experiment status. Experiment status.
""" """
resp = self._experiment_rest_get(self.port, '/check-status') resp = rest.get(self.port, '/check-status')
return resp['status'] return resp['status']
def get_trial_job(self, trial_job_id: str): def get_trial_job(self, trial_job_id: str):
...@@ -252,7 +214,7 @@ class Experiment: ...@@ -252,7 +214,7 @@ class Experiment:
TrialJob TrialJob
A `TrialJob` instance corresponding to `trial_job_id`. A `TrialJob` instance corresponding to `trial_job_id`.
""" """
resp = self._experiment_rest_get(self.port, '/trial-jobs/{}'.format(trial_job_id)) resp = rest.get(self.port, '/trial-jobs/{}'.format(trial_job_id))
return TrialJob(**resp) return TrialJob(**resp)
def list_trial_jobs(self): def list_trial_jobs(self):
...@@ -264,7 +226,7 @@ class Experiment: ...@@ -264,7 +226,7 @@ class Experiment:
list list
List of `TrialJob`. List of `TrialJob`.
""" """
resp = self._experiment_rest_get(self.port, '/trial-jobs') resp = rest.get(self.port, '/trial-jobs')
return [TrialJob(**trial_job) for trial_job in resp] return [TrialJob(**trial_job) for trial_job in resp]
def get_job_statistics(self): def get_job_statistics(self):
...@@ -276,7 +238,7 @@ class Experiment: ...@@ -276,7 +238,7 @@ class Experiment:
dict dict
Job statistics information. Job statistics information.
""" """
resp = self._experiment_rest_get(self.port, '/job-statistics') resp = rest.get(self.port, '/job-statistics')
return resp return resp
def get_job_metrics(self, trial_job_id=None): def get_job_metrics(self, trial_job_id=None):
...@@ -294,7 +256,7 @@ class Experiment: ...@@ -294,7 +256,7 @@ class Experiment:
Each key is a trialJobId, the corresponding value is a list of `TrialMetricData`. Each key is a trialJobId, the corresponding value is a list of `TrialMetricData`.
""" """
api = '/metric-data/{}'.format(trial_job_id) if trial_job_id else '/metric-data' api = '/metric-data/{}'.format(trial_job_id) if trial_job_id else '/metric-data'
resp = self._experiment_rest_get(self.port, api) resp = rest.get(self.port, api)
metric_dict = {} metric_dict = {}
for metric in resp: for metric in resp:
trial_id = metric["trialJobId"] trial_id = metric["trialJobId"]
...@@ -313,7 +275,7 @@ class Experiment: ...@@ -313,7 +275,7 @@ class Experiment:
dict dict
The profile of the experiment. The profile of the experiment.
""" """
resp = self._experiment_rest_get(self.port, '/experiment') resp = rest.get(self.port, '/experiment')
return resp return resp
def get_experiment_metadata(self, exp_id: str): def get_experiment_metadata(self, exp_id: str):
...@@ -340,7 +302,7 @@ class Experiment: ...@@ -340,7 +302,7 @@ class Experiment:
list list
The experiments metadata. The experiments metadata.
""" """
resp = self._experiment_rest_get(self.port, '/experiments-info') resp = rest.get(self.port, '/experiments-info')
return resp return resp
def export_data(self): def export_data(self):
...@@ -352,7 +314,7 @@ class Experiment: ...@@ -352,7 +314,7 @@ class Experiment:
list list
List of `TrialResult`. List of `TrialResult`.
""" """
resp = self._experiment_rest_get(self.port, '/export-data') resp = rest.get(self.port, '/export-data')
return [TrialResult(**trial_result) for trial_result in resp] return [TrialResult(**trial_result) for trial_result in resp]
def _get_query_type(self, key: str): def _get_query_type(self, key: str):
...@@ -379,7 +341,8 @@ class Experiment: ...@@ -379,7 +341,8 @@ class Experiment:
api = '/experiment{}'.format(self._get_query_type(key)) api = '/experiment{}'.format(self._get_query_type(key))
experiment_profile = self.get_experiment_profile() experiment_profile = self.get_experiment_profile()
experiment_profile['params'][key] = value experiment_profile['params'][key] = value
self._experiment_rest_put(self.port, api, experiment_profile) rest.put(self.port, api, experiment_profile)
logging.info('Successfully update %s.', key)
def update_trial_concurrency(self, value: int): def update_trial_concurrency(self, value: int):
""" """
......
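The new `stop()` in this file first asks the NNI manager to shut down through the REST API and kills the process only if that request fails. The pattern can be sketched in isolation as follows. `stop_process`, `graceful_stop`, and `force_kill` are hypothetical names chosen for this example; in the diff the two roles are played by `rest.delete(self.port, '/experiment')` and `kill_command(self._proc.pid)`.

```python
import logging
from typing import Callable, List

_logger = logging.getLogger('stop_sketch')


def stop_process(graceful_stop: Callable[[], None], force_kill: Callable[[], None]) -> str:
    """Try the graceful path first; fall back to killing only if it raises."""
    try:
        graceful_stop()
        return 'graceful'
    except Exception as e:
        _logger.warning('Cannot gracefully stop (%s), killing process...', e)
        force_kill()
        return 'killed'


calls: List[str] = []


def failing_delete() -> None:
    # Simulates a REST server that is already unreachable.
    raise ConnectionError('REST server unreachable')


outcome_ok = stop_process(lambda: calls.append('delete'), lambda: calls.append('kill'))
outcome_fallback = stop_process(failing_delete, lambda: calls.append('kill'))
print(outcome_ok, outcome_fallback, calls)
```

The kill callback runs only in the second call, so a healthy server gets the chance to cancel trials and flush its state before exiting.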
...@@ -10,22 +10,19 @@ from typing import Optional, Tuple ...@@ -10,22 +10,19 @@ from typing import Optional, Tuple
import colorama import colorama
import nni_node # pylint: disable=import-error import nni_node # pylint: disable=import-error
import nni.runtime.protocol
from .config import ExperimentConfig from .config import ExperimentConfig
from .config import convert from .config import convert
from .pipe import Pipe
from . import rest from . import rest
from ..tools.nnictl.config_utils import Experiments from ..tools.nnictl.config_utils import Experiments
_logger = logging.getLogger('nni.experiment') _logger = logging.getLogger('nni.experiment')
def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]: def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen:
pipe = None
proc = None proc = None
config.validate(initialized_tuner=True) config.validate(initialized_tuner=False)
_ensure_port_idle(port) _ensure_port_idle(port)
if isinstance(config.training_service, list): # hybrid training service if isinstance(config.training_service, list): # hybrid training service
_ensure_port_idle(port + 1, 'Hybrid training service requires an additional port') _ensure_port_idle(port + 1, 'Hybrid training service requires an additional port')
...@@ -34,12 +31,7 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo ...@@ -34,12 +31,7 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
try: try:
_logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL) _logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
pipe = Pipe(exp_id) start_time, proc = _start_rest_server(config, port, debug, exp_id)
start_time, proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
_logger.info('Connecting IPC pipe...')
pipe_file = pipe.connect()
nni.runtime.protocol._in_file = pipe_file
nni.runtime.protocol._out_file = pipe_file
_logger.info('Statring web server...') _logger.info('Statring web server...')
_check_rest_server(port) _check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
...@@ -47,16 +39,13 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo ...@@ -47,16 +39,13 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
config.experiment_name, proc.pid, config.experiment_working_directory) config.experiment_name, proc.pid, config.experiment_working_directory)
_logger.info('Setting up...') _logger.info('Setting up...')
_init_experiment(config, port, debug) _init_experiment(config, port, debug)
return proc, pipe return proc
except Exception as e: except Exception as e:
_logger.error('Create experiment failed') _logger.error('Create experiment failed')
if proc is not None: if proc is not None:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
proc.kill() proc.kill()
if pipe is not None:
with contextlib.suppress(Exception):
pipe.close()
raise e raise e
...@@ -68,7 +57,7 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None: ...@@ -68,7 +57,7 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
raise RuntimeError(f'Port {port} is not idle {message}') raise RuntimeError(f'Port {port} is not idle {message}')
def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str) -> Tuple[int, Popen]: def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str) -> Tuple[int, Popen]:
if isinstance(config.training_service, list): if isinstance(config.training_service, list):
ts = 'hybrid' ts = 'hybrid'
else: else:
...@@ -82,7 +71,6 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim ...@@ -82,7 +71,6 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
'experiment_id': experiment_id, 'experiment_id': experiment_id,
'start_mode': 'new', 'start_mode': 'new',
'log_level': 'debug' if debug else 'info', 'log_level': 'debug' if debug else 'info',
'dispatcher_pipe': pipe_path,
} }
node_dir = Path(nni_node.__path__[0]) node_dir = Path(nni_node.__path__[0])
...@@ -97,7 +85,8 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim ...@@ -97,7 +85,8 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
from subprocess import CREATE_NEW_PROCESS_GROUP from subprocess import CREATE_NEW_PROCESS_GROUP
proc = Popen(cmd, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP) proc = Popen(cmd, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
else: else:
proc = Popen(cmd, cwd=node_dir) import os
proc = Popen(cmd, cwd=node_dir, preexec_fn=os.setpgrp)
return int(time.time() * 1000), proc return int(time.time() * 1000), proc
......
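The last hunk above starts the NNI manager in its own process group on POSIX systems via `preexec_fn=os.setpgrp`, mirroring the existing `CREATE_NEW_PROCESS_GROUP` branch for Windows. The diff does not state the motivation; a common reason is that terminal signals such as Ctrl+C are delivered to the whole foreground process group, so a child in a separate group is not interrupted together with the Python script. The sketch below only demonstrates the group separation itself. `spawn_in_own_group` is a hypothetical helper, and the demonstration part assumes a POSIX system.

```python
import os
import subprocess
import sys


def spawn_in_own_group(cmd):
    """Start cmd in a new process group, using the same two branches as the diff."""
    if sys.platform == 'win32':
        return subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
    return subprocess.Popen(cmd, stdout=subprocess.PIPE, preexec_fn=os.setpgrp)


if sys.platform != 'win32':
    # The child reports its own process group id on stdout.
    child = spawn_in_own_group([sys.executable, '-c', 'import os; print(os.getpgrp())'])
    out, _ = child.communicate()
    child_pgid = int(out.decode().strip())
    print('parent group:', os.getpgrp(), 'child group:', child_pgid)
```

On POSIX the child becomes the leader of a fresh group, so its group id equals its own pid and differs from the parent's group id.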
import logging import logging
from typing import Any from typing import Any, Optional
import requests import requests
...@@ -8,25 +8,28 @@ _logger = logging.getLogger(__name__) ...@@ -8,25 +8,28 @@ _logger = logging.getLogger(__name__)
url_template = 'http://localhost:{}/api/v1/nni{}' url_template = 'http://localhost:{}/api/v1/nni{}'
timeout = 20 timeout = 20
-def get(port: int, api: str) -> Any:
+def request(method: str, port: Optional[int], api: str, data: Any = None) -> Any:
+    if port is None:
+        raise RuntimeError('Experiment is not running')
     url = url_template.format(port, api)
-    resp = requests.get(url, timeout=timeout)
+    if data is None:
+        resp = requests.request(method, url, timeout=timeout)
+    else:
+        resp = requests.request(method, url, json=data, timeout=timeout)
     if not resp.ok:
-        _logger.error('rest request GET %s %s failed: %s %s', port, api, resp.status_code, resp.text)
+        _logger.error('rest request %s %s failed: %s %s', method.upper(), url, resp.status_code, resp.text)
     resp.raise_for_status()
-    return resp.json()
+    if method.lower() in ['get', 'post']:
+        return resp.json()

-def post(port: int, api: str, data: Any) -> Any:
-    url = url_template.format(port, api)
-    resp = requests.post(url, json=data, timeout=timeout)
-    if not resp.ok:
-        _logger.error('rest request POST %s %s failed: %s %s', port, api, resp.status_code, resp.text)
-    resp.raise_for_status()
-    return resp.json()
+def get(port: Optional[int], api: str) -> Any:
+    return request('get', port, api)

-def put(port: int, api: str, data: Any) -> None:
-    url = url_template.format(port, api)
-    resp = requests.put(url, json=data, timeout=timeout)
-    if not resp.ok:
-        _logger.error('rest request PUT %s %s failed: %s', port, api, resp.status_code)
-    resp.raise_for_status()
+def post(port: Optional[int], api: str, data: Any) -> Any:
+    return request('post', port, api, data)
+
+def put(port: Optional[int], api: str, data: Any) -> None:
+    request('put', port, api, data)
+
+def delete(port: Optional[int], api: str) -> None:
+    request('delete', port, api)
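The rewritten `rest.py` funnels every HTTP verb through one `request` function, which also takes over the "experiment is not running" check that `Experiment` used to perform in `_experiment_rest_get` and `_experiment_rest_put`. The sketch below reproduces that control flow without touching the network. `make_client` and `fake_send` are hypothetical; `fake_send` stands in for `requests.request` so the example runs offline.

```python
from typing import Any, Callable, List, Optional, Tuple

URL_TEMPLATE = 'http://localhost:{}/api/v1/nni{}'


def make_client(send: Callable[[str, str, Any], Any]) -> Callable[..., Any]:
    """Build a request function around an injectable transport (assumption for testing)."""
    def request(method: str, port: Optional[int], api: str, data: Any = None) -> Any:
        # Guard moved into the REST layer: a stopped experiment has port None.
        if port is None:
            raise RuntimeError('Experiment is not running')
        url = URL_TEMPLATE.format(port, api)
        body = send(method, url, data)
        # Only GET and POST hand a parsed body back, as in the diff.
        return body if method.lower() in ('get', 'post') else None
    return request


sent: List[Tuple[str, str, Any]] = []


def fake_send(method: str, url: str, data: Any) -> Any:
    sent.append((method, url, data))
    return {'status': 'RUNNING'}


request = make_client(fake_send)
status = request('get', 8080, '/check-status')
deleted = request('delete', 8080, '/experiment')
try:
    request('get', None, '/check-status')
    guard_error = None
except RuntimeError as e:
    guard_error = str(e)
print(status, deleted, guard_error)
```

Keeping the guard in one place means every wrapper, including the new `delete`, fails with the same error before any URL is built.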
...@@ -84,6 +84,8 @@ abstract class Manager { ...@@ -84,6 +84,8 @@ abstract class Manager {
public abstract startExperiment(experimentParams: ExperimentParams): Promise<string>; public abstract startExperiment(experimentParams: ExperimentParams): Promise<string>;
public abstract resumeExperiment(readonly: boolean): Promise<void>; public abstract resumeExperiment(readonly: boolean): Promise<void>;
public abstract stopExperiment(): Promise<void>; public abstract stopExperiment(): Promise<void>;
public abstract stopExperimentTopHalf(): Promise<void>;
public abstract stopExperimentBottomHalf(): Promise<void>;
public abstract getExperimentProfile(): Promise<ExperimentProfile>; public abstract getExperimentProfile(): Promise<ExperimentProfile>;
public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>; public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>;
public abstract importData(data: string): Promise<void>; public abstract importData(data: string): Promise<void>;
......
...@@ -25,6 +25,7 @@ import { ...@@ -25,6 +25,7 @@ import {
REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE, IMPORT_DATA REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE, IMPORT_DATA
} from './commands'; } from './commands';
import { createDispatcherInterface, createDispatcherPipeInterface, IpcInterface } from './ipcInterface'; import { createDispatcherInterface, createDispatcherPipeInterface, IpcInterface } from './ipcInterface';
import { NNIRestServer } from '../rest_server/nniRestServer';
/** /**
* NNIManager which implements Manager interface * NNIManager which implements Manager interface
...@@ -296,10 +297,71 @@ class NNIManager implements Manager { ...@@ -296,10 +297,71 @@ class NNIManager implements Manager {
} }
public async stopExperiment(): Promise<void> { public async stopExperiment(): Promise<void> {
await this.stopExperimentTopHalf();
await this.stopExperimentBottomHalf();
}
public async stopExperimentTopHalf(): Promise<void> {
this.setStatus('STOPPING'); this.setStatus('STOPPING');
this.log.info('Stopping experiment, cleaning up ...'); this.log.info('Stopping experiment, cleaning up ...');
await this.experimentDoneCleanUp();
if (this.dispatcher === undefined) {
this.log.error('Tuner has not been setup');
return;
}
this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
if (this.dispatcherPid > 0) {
this.dispatcher.sendCommand(TERMINATE);
// gracefully terminate tuner and assessor here, wait at most 30 seconds.
for (let i: number = 0; i < 30; i++) {
if (!await isAlive(this.dispatcherPid)) {
break;
}
await delay(1000);
}
await killPid(this.dispatcherPid);
}
this.dispatcher = undefined;
}
public async stopExperimentBottomHalf(): Promise<void> {
const trialJobList: TrialJobDetail[] = await this.trainingService.listTrialJobs();
// DON'T try to make it in parallel, the training service may not handle it well.
// If there is performance concern, consider to support batch cancellation on training service.
for (const trialJob of trialJobList) {
if (trialJob.status === 'RUNNING' ||
trialJob.status === 'WAITING') {
try {
this.log.info(`cancelTrialJob: ${trialJob.id}`);
await this.trainingService.cancelTrialJob(trialJob.id);
} catch (error) {
this.log.debug(`ignorable error on canceling trial ${trialJob.id}. ${error}`);
}
}
}
await this.trainingService.cleanUp();
if (this.experimentProfile.endTime === undefined) {
this.setEndtime();
}
await this.storeExperimentProfile();
this.setStatus('STOPPED');
this.log.info('Experiment stopped.'); this.log.info('Experiment stopped.');
let hasError: boolean = false;
try {
this.experimentManager.stop();
this.dataStore.close();
await component.get<NNIRestServer>(NNIRestServer).stop();
} catch (err) {
hasError = true;
this.log.error(`${err.stack}`);
} finally {
this.log.close();
process.exit(hasError ? 1 : 0);
}
} }
public async getMetricData(trialJobId?: string, metricType?: MetricType): Promise<MetricDataRecord[]> { public async getMetricData(trialJobId?: string, metricType?: MetricType): Promise<MetricDataRecord[]> {
...@@ -437,45 +499,6 @@ class NNIManager implements Manager { ...@@ -437,45 +499,6 @@ class NNIManager implements Manager {
return; return;
} }
private async experimentDoneCleanUp(): Promise<void> {
if (this.dispatcher === undefined) {
throw new Error('Error: tuner has not been setup');
}
this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
if (this.dispatcherPid > 0) {
this.dispatcher.sendCommand(TERMINATE);
let tunerAlive: boolean = true;
// gracefully terminate tuner and assessor here, wait at most 30 seconds.
for (let i: number = 0; i < 30; i++) {
if (!tunerAlive) { break; }
tunerAlive = await isAlive(this.dispatcherPid);
await delay(1000);
}
await killPid(this.dispatcherPid);
}
const trialJobList: TrialJobDetail[] = await this.trainingService.listTrialJobs();
// DON'T try to make it in parallel, the training service may not handle it well.
// If there is performance concern, consider to support batch cancellation on training service.
for (const trialJob of trialJobList) {
if (trialJob.status === 'RUNNING' ||
trialJob.status === 'WAITING') {
try {
this.log.info(`cancelTrialJob: ${trialJob.id}`);
await this.trainingService.cancelTrialJob(trialJob.id);
} catch (error) {
this.log.debug(`ignorable error on canceling trial ${trialJob.id}. ${error}`);
}
}
}
await this.trainingService.cleanUp();
if (this.experimentProfile.endTime === undefined) {
this.setEndtime();
}
await this.storeExperimentProfile();
this.setStatus('STOPPED');
}
private async periodicallyUpdateExecDuration(): Promise<void> { private async periodicallyUpdateExecDuration(): Promise<void> {
let count: number = 1; let count: number = 1;
while (!['ERROR', 'STOPPING', 'STOPPED'].includes(this.status.status)) { while (!['ERROR', 'STOPPING', 'STOPPED'].includes(this.status.status)) {
......
@@ -174,25 +174,8 @@ mkDirP(getLogDir())
console.error(`Failed to create log dir: ${err.stack}`);
});
async function cleanUp(): Promise<void> { function cleanUp(): void {
const log: Logger = getLogger(); (component.get(Manager) as Manager).stopExperiment();
let hasError: boolean = false;
try {
const nniManager: Manager = component.get(Manager);
await nniManager.stopExperiment();
const experimentManager: ExperimentManager = component.get(ExperimentManager);
await experimentManager.stop();
const ds: DataStore = component.get(DataStore);
await ds.close();
const restServer: NNIRestServer = component.get(NNIRestServer);
await restServer.stop();
} catch (err) {
hasError = true;
log.error(`${err.stack}`);
} finally {
log.close();
process.exit(hasError ? 1 : 0);
}
}
process.on('SIGTERM', cleanUp);
......
@@ -64,6 +64,7 @@ class NNIRestHandler {
this.getTrialLog(router);
this.exportData(router);
this.getExperimentsInfo(router);
this.stop(router);
// Express-joi-validator configuration
router.use((err: any, _req: Request, res: Response, _next: any) => {
@@ -317,6 +318,15 @@ class NNIRestHandler {
});
}
private stop(router: Router): void {
router.delete('/experiment', (req: Request, res: Response) => {
this.nniManager.stopExperimentTopHalf().then(() => {
res.send();
this.nniManager.stopExperimentBottomHalf();
});
});
}
private setErrorPathForFailedJob(jobInfo: TrialJobInfo): TrialJobInfo {
if (jobInfo === undefined || jobInfo.status !== 'FAILED' || jobInfo.logPath === undefined) {
return jobInfo;
......
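Note on the `DELETE /experiment` handler above: it awaits `stopExperimentTopHalf()`, sends the response, and only then starts `stopExperimentBottomHalf()`. The following is a minimal sketch of that ordering, not the actual NNI implementation. `SketchManager`, `handleStop` and the `events` array are hypothetical names introduced for illustration; only the two method names come from the diff. The sketch awaits the bottom half so the ordering can be observed, whereas the handler in the diff does not await it.

```typescript
// Hypothetical stand-in for the manager; only the two method names are taken from the diff.
class SketchManager {
    public events: string[] = [];

    // Top half: the part that must complete before the HTTP response is sent.
    public async stopExperimentTopHalf(): Promise<void> {
        this.events.push('top-half');
    }

    // Bottom half: the remaining cleanup, started after the client has its response.
    public async stopExperimentBottomHalf(): Promise<void> {
        this.events.push('bottom-half');
    }
}

// Same ordering as the handler: await the top half, respond, then run the bottom half.
async function handleStop(manager: SketchManager, respond: () => void): Promise<void> {
    await manager.stopExperimentTopHalf();
    respond();
    // Assumption: errors from the bottom half are logged, since no client is waiting for them.
    await manager.stopExperimentBottomHalf().catch((err: Error) => console.error(err));
}
```

One thing the sketch makes visible: if the top half rejects, `respond` is never called. The handler in the diff has no rejection branch either, so a failing top half would presumably leave the request without a response.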
@@ -114,6 +114,12 @@ export class MockedNNIManager extends Manager {
public stopExperiment(): Promise<void> {
throw new MethodNotImplementedError();
}
public stopExperimentTopHalf(): Promise<void> {
throw new MethodNotImplementedError();
}
public stopExperimentBottomHalf(): Promise<void> {
throw new MethodNotImplementedError();
}
public getMetricData(trialJobId: string, metricType: MetricType): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError();
}
......
@@ -17,4 +17,4 @@
"exclude": [
"node_modules"
]
} }
\ No newline at end of file