"docs/vscode:/vscode.git/clone" did not exist on "cfda0dae8cd5df5ef96c53ce4d488579322deebb"
Unverified commit bc0f8f33 authored by liuzhe-lz, committed by GitHub

Refactor code hierarchy part 3: Unit test (#3037)

parent 80b6cb3b
@@ -39,7 +39,7 @@
### Random Mutator
```eval_rst
-.. autoclass:: nni.nas.pytorch.random.RandomMutator
+.. autoclass:: nni.algorithms.nas.pytorch.random.RandomMutator
    :members:
```
@@ -74,9 +74,9 @@
### Distributed NAS
```eval_rst
-.. autofunction:: nni.nas.pytorch.classic_nas.get_and_apply_next_architecture
-.. autoclass:: nni.nas.pytorch.classic_nas.mutator.ClassicMutator
+.. autofunction:: nni.algorithms.nas.pytorch.classic_nas.get_and_apply_next_architecture
+.. autoclass:: nni.algorithms.nas.pytorch.classic_nas.mutator.ClassicMutator
    :members:
```
...
@@ -90,13 +90,13 @@ By default, it will use `architecture_final.json`. This architecture is provided
### PyTorch
```eval_rst
-.. autoclass:: nni.nas.pytorch.spos.SPOSEvolution
+.. autoclass:: nni.algorithms.nas.pytorch.spos.SPOSEvolution
    :members:
-.. autoclass:: nni.nas.pytorch.spos.SPOSSupernetTrainer
+.. autoclass:: nni.algorithms.nas.pytorch.spos.SPOSSupernetTrainer
    :members:
-.. autoclass:: nni.nas.pytorch.spos.SPOSSupernetTrainingMutator
+.. autoclass:: nni.algorithms.nas.pytorch.spos.SPOSSupernetTrainingMutator
    :members:
```
...
@@ -9,14 +9,14 @@ If a user wants to implement a customized Advisor, she/he only needs to:
**1. Define an Advisor inheriting from the MsgDispatcherBase class.** For example:
```python
-from nni.msg_dispatcher_base import MsgDispatcherBase
+from nni.runtime.msg_dispatcher_base import MsgDispatcherBase

class CustomizedAdvisor(MsgDispatcherBase):
    def __init__(self, ...):
        ...
```
-**2. Implement the methods with prefix `handle_` except `handle_request`.** You might find [docs](https://nni.readthedocs.io/en/latest/sdk_reference.html#nni.msg_dispatcher_base.MsgDispatcherBase) for `MsgDispatcherBase` helpful.
+**2. Implement the methods with prefix `handle_` except `handle_request`.** You might find [docs](https://nni.readthedocs.io/en/latest/sdk_reference.html#nni.runtime.msg_dispatcher_base.MsgDispatcherBase) for `MsgDispatcherBase` helpful.
**3. Configure your customized Advisor in the experiment YAML config file.**
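For step 2, a minimal sketch of what such a skeleton can look like under the new `nni.runtime` path. The handler names below are the ones commonly overridden and the bodies are placeholders; check the `MsgDispatcherBase` reference linked above for the authoritative list and expected behavior:

```python
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase


class CustomizedAdvisor(MsgDispatcherBase):
    """Skeleton only -- every handler body below is a placeholder."""

    def handle_initialize(self, data):
        # `data` carries the search space when the experiment starts.
        self.search_space = data

    def handle_update_search_space(self, data):
        self.search_space = data

    def handle_request_trial_jobs(self, data):
        # `data` is the number of trials requested; generate and submit them here.
        pass

    def handle_report_metric_data(self, data):
        # Called for intermediate and final metrics reported by trials.
        pass

    def handle_trial_end(self, data):
        pass
```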
...
@@ -22,31 +22,31 @@
.. autoclass:: nni.tuner.Tuner
    :members:
-.. autoclass:: nni.hyperopt_tuner.hyperopt_tuner.HyperoptTuner
+.. autoclass:: nni.algorithms.hpo.hyperopt_tuner.hyperopt_tuner.HyperoptTuner
    :members:
-.. autoclass:: nni.evolution_tuner.evolution_tuner.EvolutionTuner
+.. autoclass:: nni.algorithms.hpo.evolution_tuner.evolution_tuner.EvolutionTuner
    :members:
-.. autoclass:: nni.smac_tuner.SMACTuner
+.. autoclass:: nni.algorithms.hpo.smac_tuner.SMACTuner
    :members:
-.. autoclass:: nni.gridsearch_tuner.GridSearchTuner
+.. autoclass:: nni.algorithms.hpo.gridsearch_tuner.GridSearchTuner
    :members:
-.. autoclass:: nni.networkmorphism_tuner.networkmorphism_tuner.NetworkMorphismTuner
+.. autoclass:: nni.algorithms.hpo.networkmorphism_tuner.networkmorphism_tuner.NetworkMorphismTuner
    :members:
-.. autoclass:: nni.metis_tuner.metis_tuner.MetisTuner
+.. autoclass:: nni.algorithms.hpo.metis_tuner.metis_tuner.MetisTuner
    :members:
-.. autoclass:: nni.ppo_tuner.PPOTuner
+.. autoclass:: nni.algorithms.hpo.ppo_tuner.PPOTuner
    :members:
-.. autoclass:: nni.batch_tuner.batch_tuner.BatchTuner
+.. autoclass:: nni.algorithms.hpo.batch_tuner.batch_tuner.BatchTuner
    :members:
-.. autoclass:: nni.gp_tuner.gp_tuner.GPTuner
+.. autoclass:: nni.algorithms.hpo.gp_tuner.gp_tuner.GPTuner
    :members:
```
@@ -59,23 +59,23 @@
.. autoclass:: nni.assessor.AssessResult
    :members:
-.. autoclass:: nni.curvefitting_assessor.CurvefittingAssessor
+.. autoclass:: nni.algorithms.hpo.curvefitting_assessor.CurvefittingAssessor
    :members:
-.. autoclass:: nni.medianstop_assessor.MedianstopAssessor
+.. autoclass:: nni.algorithms.hpo.medianstop_assessor.MedianstopAssessor
    :members:
```
## Advisor
```eval_rst
-.. autoclass:: nni.msg_dispatcher_base.MsgDispatcherBase
+.. autoclass:: nni.runtime.msg_dispatcher_base.MsgDispatcherBase
    :members:
-.. autoclass:: nni.hyperband_advisor.hyperband_advisor.Hyperband
+.. autoclass:: nni.algorithms.hpo.hyperband_advisor.hyperband_advisor.Hyperband
    :members:
-.. autoclass:: nni.bohb_advisor.bohb_advisor.BOHB
+.. autoclass:: nni.algorithms.hpo.bohb_advisor.bohb_advisor.BOHB
    :members:
```
...
@@ -4,14 +4,14 @@
import torch
from apex.parallel import DistributedDataParallel  # pylint: disable=import-error
-from nni.nas.pytorch.darts import DartsMutator  # pylint: disable=wrong-import-order
+from nni.algorithms.nas.pytorch.darts import DartsMutator  # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutables import LayerChoice  # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutator import Mutator  # pylint: disable=wrong-import-order

class RegularizedDartsMutator(DartsMutator):
    """
-    This is :class:`~nni.nas.pytorch.darts.DartsMutator` basically, with two differences.
+    This is :class:`~nni.algorithms.nas.pytorch.darts.DartsMutator` basically, with two differences.
    1. Choices can be cut (bypassed). This is done by ``cut_choices``. Cut choices will not be used in
       the forward pass and thus consume no memory.
...
@@ -4,7 +4,7 @@
import logging
import numpy as np
-from nni.nas.pytorch.random import RandomMutator
+from nni.algorithms.nas.pytorch.random import RandomMutator

_logger = logging.getLogger(__name__)
...
@@ -71,7 +71,8 @@ class Mutator(BaseMutator):
                axis = 0
            else:
                axis = -1
-            return tf.concat(tensor_list, axis=axis)
+            return tf.concat(tensor_list, axis=axis)  # pylint: disable=E1120,E1123
+            # pylint issue #3613
        raise ValueError('Unrecognized reduction policy: "{}"'.format(reduction_type))

    def _get_decision(self, mutable):
...
@@ -22,6 +22,7 @@ def generate_search_space(code_dir):
    Return a serializable search space object.
    code_dir: directory path of source files (str)
    """
+    code_dir = str(code_dir)
    search_space = {}
    if code_dir.endswith(slash):
@@ -65,6 +66,8 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id='', nas_mode=None):
    dst_dir: directory to place generated files (str)
    nas_mode: the mode of NAS given that NAS interface is used
    """
+    src_dir, dst_dir = str(src_dir), str(dst_dir)
    if src_dir[-1] == slash:
        src_dir = src_dir[:-1]
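The `str()` casts added above mean `generate_search_space` and `expand_annotations` accept `pathlib.Path` arguments as well as plain strings. A minimal usage sketch; the directories and the `nni.tools.annotation` import location are illustrative assumptions:

```python
from pathlib import Path

from nni.tools.annotation import expand_annotations, generate_search_space

# Hypothetical directories; Path objects work because both helpers
# convert their arguments with str() before doing string manipulation.
src = Path('examples/trials/mnist-annotation')
dst = Path('/tmp/mnist-annotation-expanded')

expand_annotations(src, dst, exp_id='example_exp', trial_id='example_trial')
print(generate_search_space(dst))
```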
...
@@ -4,7 +4,7 @@
import ast
import astor
-from .utils import ast_Num, ast_Str
+from .utils import ast_Num, ast_Str, lineno
# pylint: disable=unidiomatic-typecheck
@@ -274,7 +274,7 @@ class Transformer(ast.NodeTransformer):
    def visit(self, node):
        if isinstance(node, (ast.expr, ast.stmt)):
-            self.last_line = node.lineno
+            self.last_line = lineno(node)
        # do nothing for root
        if not self.stack:
@@ -316,7 +316,7 @@ class Transformer(ast.NodeTransformer):
            return parse_annotation(string[1:])  # expand annotation string to code
        if string.startswith('@nni.mutable_layers'):
-            nodes = parse_annotation_mutable_layers(string[1:], node.lineno, self.nas_mode)
+            nodes = parse_annotation_mutable_layers(string[1:], lineno(node), self.nas_mode)
            return nodes
        if string.startswith('@nni.variable') \
...
@@ -6,7 +6,7 @@ import numbers
import astor
-from .utils import ast_Num, ast_Str
+from .utils import ast_Num, ast_Str, lineno
# pylint: disable=unidiomatic-typecheck
@@ -65,7 +65,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
        if func not in _ss_funcs:
            return node
-        self.last_line = node.lineno
+        self.last_line = lineno(node)
        if func == 'mutable_layer':
            self.generate_mutable_layer_search_space(node.args)
...
@@ -5,7 +5,7 @@ import ast
import astor
from nni.tools.nnictl.common_utils import print_warning
-from .utils import ast_Num, ast_Str
+from .utils import ast_Num, ast_Str, lineno
# pylint: disable=unidiomatic-typecheck
@@ -257,7 +257,7 @@ class Transformer(ast.NodeTransformer):
    def visit(self, node):
        if isinstance(node, (ast.expr, ast.stmt)):
-            self.last_line = node.lineno
+            self.last_line = lineno(node)
        # do nothing for root
        if not self.stack:
@@ -311,7 +311,7 @@ class Transformer(ast.NodeTransformer):
                args=[ast_Str(s='nni.report_final_result: '), arg], keywords=[]))
        if string.startswith('@nni.mutable_layers'):
-            return parse_annotation_mutable_layers(string[1:], node.lineno)
+            return parse_annotation_mutable_layers(string[1:], lineno(node))
        if string.startswith('@nni.variable') \
            or string.startswith('@nni.function_choice'):
...
@@ -7,9 +7,16 @@ from sys import version_info
if version_info >= (3, 8):
    ast_Num = ast_Str = ast_Bytes = ast_NameConstant = ast_Ellipsis = ast.Constant
+
+    def lineno(ast_node):
+        return ast_node.end_lineno
+
else:
    ast_Num = ast.Num
    ast_Str = ast.Str
    ast_Bytes = ast.Bytes
    ast_NameConstant = ast.NameConstant
    ast_Ellipsis = ast.Ellipsis
+
+    def lineno(ast_node):
+        return ast_node.lineno
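The `lineno()` helper is presumably needed because Python 3.8 moved the reported line number of multi-line literals: before 3.8, `lineno` pointed at the last line of the literal, while 3.8+ reports the first line and exposes the last one as `end_lineno`. A small self-contained check of that behavior (illustrative only):

```python
import ast
from sys import version_info

# A string literal that spans three source lines.
tree = ast.parse('x = """line1\nline2\nline3"""')
string_node = tree.body[0].value

if version_info >= (3, 8):
    # 3.8+: lineno is the first line of the literal, end_lineno the last.
    print(string_node.lineno, string_node.end_lineno)  # 1 3
else:
    # Pre-3.8: lineno already pointed at the last line.
    print(string_node.lineno)  # 3
```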
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
TypeScript modules of NNI.
As a Python package, this only contains "package data".
"""
# To reduce debug cost, steps are sorted differently on each platform,
# so that a bug in any module will cause at least one platform to fail quickly.
jobs:
- job: 'ubuntu_latest'
  pool:
    # FIXME: In ubuntu-20.04 Python interpreter crashed during SMAC UT
    vmImage: 'ubuntu-18.04'
  # This platform tests lint and doc first.
  steps:
  - script: |
      set -e
      python3 -m pip install -U --upgrade pip setuptools
      python3 -m pip install -U pytest coverage
      python3 -m pip install -U pylint flake8
      echo "##vso[task.setvariable variable=PATH]${HOME}/.local/bin:${PATH}"
    displayName: 'Install Python tools'
  - script: |
      python3 setup.py develop
    displayName: 'Install NNI'
  - script: |
      set -e
      cd ts/nni_manager
      yarn eslint
      cd ../webui
      yarn eslint
    displayName: 'ESLint'
  - script: |
      set -e
      sudo apt-get install -y pandoc
      python3 -m pip install -U --upgrade pygments
      python3 -m pip install -U torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
      python3 -m pip install -U tensorflow==2.3.1
      python3 -m pip install -U keras==2.4.2
      python3 -m pip install -U gym onnx peewee thop
      python3 -m pip install -U sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
      sudo apt-get install swig -y
      nnictl package install --name=SMAC
      nnictl package install --name=BOHB
    displayName: 'Install extra dependencies'
  - script: |
      set -e
      python3 -m pylint --rcfile pylintrc nni
      python3 -m flake8 nni --count --select=E9,F63,F72,F82 --show-source --statistics
      EXCLUDES=examples/trials/mnist-nas/*/mnist*.py,examples/trials/nas_cifar10/src/cifar10/general_child.py
      python3 -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
    displayName: 'pylint and flake8'
  - script: |
      cd docs/en_US
      sphinx-build -M html . _build -W --keep-going -T
    displayName: 'Check Sphinx documentation'
  - script: |
      cd test
      python3 -m pytest ut
    displayName: 'Python unit test'
  - script: |
      set -e
      cd ts/nni_manager
      yarn test
      cd ../nasui
      CI=true yarn test
    displayName: 'TypeScript unit test'
  - script: |
      cd test
      python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
    displayName: 'Simple integration test'

- job: 'ubuntu_legacy'
  pool:
    vmImage: 'ubuntu-18.04'
  # This platform runs integration test first.
  steps:
  - script: |
      set -e
      python3 -m pip install -U --upgrade pip setuptools
      python3 -m pip install -U pytest coverage
      echo "##vso[task.setvariable variable=PATH]${HOME}/.local/bin:${PATH}"
    displayName: 'Install Python tools'
  - script: |
      python3 setup.py develop
    displayName: 'Install NNI'
  - script: |
      set -e
      python3 -m pip install -U torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
      python3 -m pip install -U tensorflow==1.15.2
      python3 -m pip install -U keras==2.1.6
      python3 -m pip install -U gym onnx peewee
      sudo apt-get install swig -y
      nnictl package install --name=SMAC
      nnictl package install --name=BOHB
    displayName: 'Install extra dependencies'
  - script: |
      cd test
      python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
    displayName: 'Simple integration test'
  - script: |
      cd test
      python3 -m pytest ut
    displayName: 'Python unit test'
  - script: |
      set -e
      cd ts/nni_manager
      yarn test
      cd ../nasui
      CI=true yarn test
    displayName: 'TypeScript unit test'

- job: 'macos'
  pool:
    vmImage: 'macOS-10.15'
  # This platform runs TypeScript unit test first.
  steps:
  - script: |
      set -e
      export PYTHON38_BIN_DIR=/usr/local/Cellar/python@3.8/`ls /usr/local/Cellar/python@3.8`/bin
      echo "##vso[task.setvariable variable=PATH]${PYTHON38_BIN_DIR}:${HOME}/Library/Python/3.8/bin:${PATH}"
      python3 -m pip install -U --upgrade pip setuptools
      python3 -m pip install -U pytest coverage
    displayName: 'Install Python tools'
  - script: |
      python3 setup.py develop
    displayName: 'Install NNI'
  - script: |
      set -e
      cd ts/nni_manager
      yarn test
      cd ../nasui
      CI=true yarn test
    displayName: 'TypeScript unit test'
  - script: |
      set -e
      # pytorch Mac binary does not support CUDA, default is cpu version
      python3 -m pip install -U torchvision==0.6.0 torch==1.5.0
      python3 -m pip install -U tensorflow==2.3.1
      brew install swig@3
      rm -f /usr/local/bin/swig
      ln -s /usr/local/opt/swig\@3/bin/swig /usr/local/bin/swig
      nnictl package install --name=SMAC
    displayName: 'Install extra dependencies'
  - script: |
      cd test
      python3 -m pytest ut
    displayName: 'Python unit test'
  - script: |
      cd test
      python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
    displayName: 'Simple integration test'

# FIXME: Windows UT is still under debugging
- job: 'windows'
  pool:
    vmImage: 'windows-2019'
  # This platform runs Python unit test first.
  steps:
  - script: |
      python -m pip install -U --upgrade pip setuptools
      python -m pip install -U pytest coverage
    displayName: 'Install Python tools'
  - script: |
      python setup.py develop
    displayName: 'Install NNI'
  - script: |
      python -m pip install -U scikit-learn==0.23.2
      python -m pip install -U torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
      python -m pip install -U tensorflow==2.3.1
    displayName: 'Install extra dependencies'
  - script: |
      cd test
      python -m pytest ut
    displayName: 'Python unit test'
    continueOnError: true
  - script: |
      cd ts/nni_manager
      yarn test
    displayName: 'TypeScript unit test'
    continueOnError: true
  - script: |
      cd test
      python nni_test/nnitest/run_tests.py --config config/pr_tests.yml
    displayName: 'Simple integration test'
    continueOnError: true
@@ -130,10 +130,13 @@ def _find_python_packages():
    return sorted(packages) + ['nni_node']

def _find_node_files():
+    if not os.path.exists('nni_node'):
+        return []
    files = []
    for dirpath, dirnames, filenames in os.walk('nni_node'):
        for filename in filenames:
            files.append((dirpath + '/' + filename)[len('nni_node/'):])
+    if '__init__.py' in files:
        files.remove('__init__.py')
    return sorted(files)
@@ -195,7 +198,8 @@ _temp_files = [
    # unit test
    'test/model_path/',
    'test/temp.json',
-    'test/ut/sdk/*.pth'
+    'test/ut/sdk/*.pth',
+    'test/ut/tools/annotation/_generated/'
]
...
@@ -22,7 +22,7 @@ import tarfile
from zipfile import ZipFile

-node_version = 'v10.22.1'
+node_version = 'v10.23.0'
yarn_version = 'v1.22.10'
@@ -32,13 +32,15 @@
    `release` is the version number without leading letter "v".
-    If `release` is None or empty, this is a development build and uses symlinks;
+    If `release` is None or empty, this is a development build and uses symlinks on Linux/macOS;
    otherwise this is a release build and copies files instead.
+    On Windows it always copies files because creating symlinks requires extra privilege.
    """
    if release or not os.environ.get('GLOBAL_TOOLCHAIN'):
        download_toolchain()
+    prepare_nni_node()
    compile_ts()
-    if release:
+    if release or sys.platform == 'win32':
        copy_nni_node(release)
    else:
        symlink_nni_node()
@@ -48,42 +50,54 @@ def clean(clean_all=False):
    Remove TypeScript-related intermediate files.
    Python intermediate files are not touched here.
    """
-    clear_nni_node()
-    for path in generated_directories:
-        shutil.rmtree(path, ignore_errors=True)
+    shutil.rmtree('nni_node', ignore_errors=True)
+    for file_or_dir in generated_files:
+        path = Path(file_or_dir)
+        if path.is_symlink() or path.is_file():
+            path.unlink()
+        elif path.is_dir():
+            shutil.rmtree(path)
    if clean_all:
        shutil.rmtree('toolchain', ignore_errors=True)
-        Path('nni_node', node_executable).unlink()
if sys.platform == 'linux' or sys.platform == 'darwin':
    node_executable = 'node'
    node_spec = f'node-{node_version}-{sys.platform}-x64'
-    node_download_url = f'https://nodejs.org/dist/latest-v10.x/{node_spec}.tar.xz'
+    node_download_url = f'https://nodejs.org/dist/{node_version}/{node_spec}.tar.xz'
    node_extractor = lambda data: tarfile.open(fileobj=BytesIO(data), mode='r:xz')
    node_executable_in_tarball = 'bin/node'
+    yarn_executable = 'yarn'
+    yarn_download_url = f'https://github.com/yarnpkg/yarn/releases/download/{yarn_version}/yarn-{yarn_version}.tar.gz'
+    path_env_seperator = ':'
elif sys.platform == 'win32':
    node_executable = 'node.exe'
    node_spec = f'node-{node_version}-win-x64'
-    node_download_url = f'https://nodejs.org/dist/latest-v10.x/{node_spec}.zip'
+    node_download_url = f'https://nodejs.org/dist/{node_version}/{node_spec}.zip'
    node_extractor = lambda data: ZipFile(BytesIO(data))
    node_executable_in_tarball = 'node.exe'
+    yarn_executable = 'yarn.cmd'
+    yarn_download_url = f'https://github.com/yarnpkg/yarn/releases/download/{yarn_version}/yarn-{yarn_version}.tar.gz'
+    path_env_seperator = ';'
else:
    raise RuntimeError('Unsupported system')
-yarn_executable = 'yarn' if sys.platform != 'win32' else 'yarn.cmd'
-yarn_download_url = f'https://github.com/yarnpkg/yarn/releases/download/{yarn_version}/yarn-{yarn_version}.tar.gz'
def download_toolchain():
    """
-    Download and extract node and yarn,
-    then copy node executable to nni_node directory.
+    Download and extract node and yarn.
    """
-    if Path('nni_node', node_executable).is_file():
+    if Path('toolchain/node', node_executable_in_tarball).is_file():
        return
    Path('toolchain').mkdir(exist_ok=True)
    import requests  # place it here so setup.py can install it before importing
@@ -105,9 +119,19 @@ def download_toolchain():
    shutil.rmtree('toolchain/yarn', ignore_errors=True)
    Path(f'toolchain/yarn-{yarn_version}').rename('toolchain/yarn')
-    src = Path('toolchain/node', node_executable_in_tarball)
-    dst = Path('nni_node', node_executable)
-    shutil.copyfile(src, dst)

+def prepare_nni_node():
+    """
+    Create a clean nni_node directory, then copy the node runtime into it.
+    """
+    shutil.rmtree('nni_node', ignore_errors=True)
+    Path('nni_node').mkdir()
+    Path('nni_node/__init__.py').write_text('"""NNI node.js modules."""\n')
+    node_src = Path('toolchain/node', node_executable_in_tarball)
+    node_dst = Path('nni_node', node_executable)
+    shutil.copyfile(node_src, node_dst)
def compile_ts():
@@ -136,7 +160,6 @@ def symlink_nni_node():
    If you manually modify and compile TS source files you don't need to install again.
    """
    _print('Creating symlinks')
-    clear_nni_node()
    for path in Path('ts/nni_manager/dist').iterdir():
        _symlink(path, Path('nni_node', path.name))
@@ -158,7 +181,6 @@ def copy_nni_node(version):
    while `package.json` in ts directory will be left unchanged.
    """
    _print('Copying files')
-    clear_nni_node()
    # copytree(..., dirs_exist_ok=True) is not supported by Python 3.6
    for path in Path('ts/nni_manager/dist').iterdir():
@@ -168,7 +190,8 @@ def copy_nni_node(version):
            shutil.copytree(path, Path('nni_node', path.name))
    package_json = json.load(open('ts/nni_manager/package.json'))
-    if version.count('.') == 1:  # node.js semver requires at least three parts
+    if version:
+        while len(version.split('.')) < 3:  # node.js semver requires at least three parts
            version = version + '.0'
    package_json['version'] = version
    json.dump(package_json, open('nni_node/package.json', 'w'), indent=2)
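The change from `version.count('.') == 1` to a `while` loop pads any short version string, one part or two, up to the three components that node.js semver expects. A quick illustration of the padding logic:

```python
# Same padding rule as above: append '.0' until there are three dot-separated parts.
for version in ('2', '2.0', '2.0.0'):
    padded = version
    while len(padded.split('.')) < 3:
        padded = padded + '.0'
    print(version, '->', padded)  # 2 -> 2.0.0, 2.0 -> 2.0.0, 2.0.0 -> 2.0.0
```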
@@ -182,28 +205,16 @@
    shutil.copyfile('ts/nasui/server.js', 'nni_node/nasui/server.js')

-def clear_nni_node():
-    """
-    Remove compiled files in nni_node.
-    Use `clean()` if you want to remove files in ts as well.
-    """
-    for path in Path('nni_node').iterdir():
-        if path.name not in ('__init__.py', 'node', 'node.exe'):
-            if path.is_symlink() or path.is_file():
-                path.unlink()
-            else:
-                shutil.rmtree(path)

_yarn_env = dict(os.environ)
-_yarn_env['PATH'] = str(Path('nni_node').resolve()) + ':' + os.environ['PATH']
-_yarn_path = Path('toolchain/yarn/bin', yarn_executable).resolve()
+# `Path('nni_node').resolve()` does not work on Windows if the directory does not exist
+_yarn_env['PATH'] = str(Path().resolve() / 'nni_node') + path_env_seperator + os.environ['PATH']
+_yarn_path = Path().resolve() / 'toolchain/yarn/bin' / yarn_executable

def _yarn(path, *args):
    if os.environ.get('GLOBAL_TOOLCHAIN'):
        subprocess.run(['yarn', *args], cwd=path, check=True)
    else:
-        subprocess.run([_yarn_path, *args], cwd=path, check=True, env=_yarn_env)
+        subprocess.run([str(_yarn_path), *args], cwd=path, check=True, env=_yarn_env)

def _symlink(target_file, link_location):
@@ -214,16 +225,22 @@ def _symlink(target_file, link_location):
def _print(*args):
-    print('\033[1;36m# ', end='')
-    print(*args, end='')
-    print('\033[0m')
+    if sys.platform == 'win32':
+        print(*args)
+    else:
+        print('\033[1;36m#', *args, '\033[0m')

-generated_directories = [
+generated_files = [
    'ts/nni_manager/dist',
    'ts/nni_manager/node_modules',
    'ts/webui/build',
    'ts/webui/node_modules',
    'ts/nasui/build',
    'ts/nasui/node_modules',
+    # unit test
+    'ts/nni_manager/exp_profile.json',
+    'ts/nni_manager/metrics.json',
+    'ts/nni_manager/trial_jobs.json',
]
"""
Unit tests of NNI Python modules.
Test cases for each module should be placed at the same path as their source files.
For example, if `nni/tools/annotation` has one test case, it should be placed at `test/ut/tools/annotation.py`;
if it has multiple test cases, they should be placed in the `test/ut/tools/annotation/` directory.
"Legacy" test cases carried over from NNI v1.x might not follow the above convention:
+ Directory `sdk` contains old test cases previously in `src/sdk/pynni/tests`.
+ Directory `tools/cmd` contains old test cases previously in `tools/cmd/tests`.
+ Directory `tools/annotation` contains old test cases previously in `tools/nni_annotation`.
+ Directory `tools/trial_tool` contains old test cases previously in `tools/nni_trial_tool/test`.
"""