Unverified commit 51d261e7, authored by J-shang, committed by GitHub

Merge pull request #4668 from microsoft/doc-refactor

parents d63a2ea3 b469e1c1
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from packaging.version import Version
 import torch
 import torch.nn as nn

 from ...serializer import basic_unit
-from ...utils import version_larger_equal

 # NOTE: support pytorch version >= 1.5.0

@@ -31,10 +31,10 @@ __all__ = [
     'Flatten', 'Hardsigmoid'
 ]

-if version_larger_equal(torch.__version__, '1.6.0'):
+if Version(torch.__version__) >= Version('1.6.0'):
     __all__.append('Hardswish')

-if version_larger_equal(torch.__version__, '1.7.0'):
+if Version(torch.__version__) >= Version('1.7.0'):
     __all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])

@@ -149,10 +149,10 @@ Transformer = basic_unit(nn.Transformer)
 Flatten = basic_unit(nn.Flatten)
 Hardsigmoid = basic_unit(nn.Hardsigmoid)

-if version_larger_equal(torch.__version__, '1.6.0'):
+if Version(torch.__version__) >= Version('1.6.0'):
     Hardswish = basic_unit(nn.Hardswish)

-if version_larger_equal(torch.__version__, '1.7.0'):
+if Version(torch.__version__) >= Version('1.7.0'):
     SiLU = basic_unit(nn.SiLU)
     Unflatten = basic_unit(nn.Unflatten)
     TripletMarginWithDistanceLoss = basic_unit(nn.TripletMarginWithDistanceLoss)
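The move from the hand-rolled ``version_larger_equal`` helper (removed from ``utils.py`` further down) to ``packaging.version`` is more than cosmetic: the helper had to strip local build suffixes such as ``+cu113`` by hand and would crash on pre-release tags. A small sketch, not part of the commit, of what ``Version`` provides out of the box:

```python
# Sketch: packaging.version handles local suffixes and pre-releases
# that the removed tuple-comparison helper had to special-case or could not parse.
from packaging.version import Version

assert Version('1.10.0+cu113') >= Version('1.6.0')   # '+cu113' needs no manual stripping
assert not Version('1.6.0rc1') >= Version('1.6.0')   # pre-release sorts before the release
# The old helper would raise ValueError here: int('0rc1') is not parseable.
```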
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 import inspect
+import os
 import warnings
 from typing import Any, TypeVar, Union

@@ -64,6 +65,12 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
         class PrimitiveOp(nn.Module):
             ...
     """
+    # Internal flag. See nni.trace
+    nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
+    if nni_trace_flag.lower() == 'disable':
+        return cls
+
     if _check_wrapped(cls, 'basic_unit'):
         return cls
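For context, ``basic_unit`` is used both as a plain function (as in the module above, e.g. ``Flatten = basic_unit(nn.Flatten)``) and as a decorator. A minimal sketch of the decorator form, assuming the commonly documented import path ``nni.retiarii``:

```python
import torch.nn as nn
from nni.retiarii import basic_unit  # import path assumed

@basic_unit
class MyOp(nn.Module):
    """A custom module treated as an opaque primitive whose init args are recorded."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.fc = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.fc(x)
```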
@@ -90,12 +97,18 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
     The wrapper serves two purposes:

     1. Capture the init parameters of python class so that it can be re-instantiated in another process.
     2. Reset uid in namespace so that the auto label counting in each model stably starts from zero.

     Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
     But in future, we might enforce ``@model_wrapper`` to be required for base model.
     """
+    # Internal flag. See nni.trace
+    nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
+    if nni_trace_flag.lower() == 'disable':
+        return cls
+
     if _check_wrapped(cls, 'model_wrapper'):
         return cls
...
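A minimal sketch of the two code paths touched here: the usual ``@model_wrapper`` decoration, and the new ``NNI_TRACE_FLAG`` escape hatch that makes both wrappers return the class untouched. Note the flag must be set before the decorator runs. Import path assumed:

```python
import os
import torch.nn as nn
from nni.retiarii import model_wrapper  # import path assumed

# With the flag set, the decoration below becomes a no-op and Net stays a plain class.
os.environ['NNI_TRACE_FLAG'] = 'disable'

@model_wrapper
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        return self.fc(x)
```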
@@ -4,6 +4,6 @@
 from .base import BaseStrategy
 from .bruteforce import Random, GridSearch
 from .evolution import RegularizedEvolution
-from .tpe_strategy import TPEStrategy
+from .tpe_strategy import TPEStrategy, TPE
 from .local_debug_strategy import _LocalDebugStrategy
 from .rl import PolicyBasedRL
@@ -23,7 +23,7 @@ class PolicyBasedRL(BaseStrategy):
     """
     Algorithm for policy-based reinforcement learning.
     This is a wrapper of algorithms provided in tianshou (PPO by default),
-    and can be easily customized with other algorithms that inherit ``BasePolicy`` (e.g., REINFORCE [1]_).
+    and can be easily customized with other algorithms that inherit ``BasePolicy`` (e.g., REINFORCE :footcite:p:`zoph2017neural`).

     Parameters
     ----------

@@ -34,12 +34,6 @@ class PolicyBasedRL(BaseStrategy):
         After each collect, trainer will sample batch from replay buffer and do the update. Default: 20.
     policy_fn : function
         Takes ``ModelEvaluationEnv`` as input and returns a policy. See ``_default_policy_fn`` for an example.
-
-    References
-    ----------
-    .. [1] Barret Zoph and Quoc V. Le, "Neural Architecture Search with Reinforcement Learning".
-        https://arxiv.org/abs/1611.01578
     """

     def __init__(self, max_collect: int = 100, trial_per_collect = 20,
...
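Constructing the strategy only needs the parameters documented above. A minimal sketch with the defaults written out explicitly (import path assumed, matching the package ``__init__`` shown earlier):

```python
from nni.retiarii.strategy import PolicyBasedRL  # import path assumed

# max_collect: number of collect phases; trial_per_collect: trials sampled per update.
strategy = PolicyBasedRL(max_collect=100, trial_per_collect=20)
```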
@@ -39,17 +39,14 @@ class TPESampler(Sampler):
         return chosen

-class TPEStrategy(BaseStrategy):
+class TPE(BaseStrategy):
     """
-    The Tree-structured Parzen Estimator (TPE) [bergstrahpo]_ is a sequential model-based optimization (SMBO) approach.
+    The Tree-structured Parzen Estimator (TPE) is a sequential model-based optimization (SMBO) approach.
+
+    Refer to :footcite:t:`bergstra2011algorithms` for details.
+
     SMBO methods sequentially construct models to approximate the performance of hyperparameters based on historical measurements,
     and then subsequently choose new hyperparameters to test based on this model.
-
-    References
-    ----------
-    .. [bergstrahpo] Bergstra et al., "Algorithms for Hyper-Parameter Optimization".
-        https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf
     """

     def __init__(self):

@@ -92,3 +89,7 @@ class TPEStrategy(BaseStrategy):
             to_be_deleted.append(_id)
         for _id in to_be_deleted:
             del self.running_models[_id]
+
+# alias for backward compatibility
+TPEStrategy = TPE
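The alias keeps existing code working while new code can use the shorter name; both names now point at the same class, as this minimal sketch illustrates (import path assumed):

```python
from nni.retiarii.strategy import TPE, TPEStrategy  # import path assumed

assert TPEStrategy is TPE   # an alias, not a subclass or a copy
strategy = TPE()
```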
@@ -9,6 +9,8 @@ from contextlib import contextmanager
 from typing import Any, List, Dict
 from pathlib import Path

+__all__ = ['NoContextError', 'ContextStack', 'ModelNamespace']

 def import_(target: str, allow_none: bool = False) -> Any:
     if target is None:

@@ -18,13 +20,6 @@ def import_(target: str, allow_none: bool = False) -> Any:
     return getattr(module, identifier)

-def version_larger_equal(a: str, b: str) -> bool:
-    # TODO: refactor later
-    a = a.split('+')[0]
-    b = b.split('+')[0]
-    return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))

 _last_uid = defaultdict(int)

 _DEFAULT_MODEL_NAMESPACE = 'model'

@@ -72,12 +67,13 @@ def get_importable_name(cls, relocate_module=False):
 class NoContextError(Exception):
+    """Exception raised when context is missing."""
     pass

 class ContextStack:
     """
-    This is to maintain a globally-accessible context envinronment that is visible to everywhere.
+    This is to maintain a globally-accessible context environment that is visible to everywhere.
     Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
     get the corresponding value in the namespace.
...
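A minimal sketch of the pattern the docstring describes, assuming both ``ContextStack`` and ``get_current_context`` are importable from this utils module:

```python
from nni.retiarii.utils import ContextStack, get_current_context  # paths assumed

with ContextStack('my_namespace', {'depth': 4}):
    # Anywhere inside the with-block, the value is globally retrievable by namespace.
    assert get_current_context('my_namespace') == {'depth': 4}
```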
@@ -61,6 +61,8 @@ def init_logger_for_command_line() -> None:
     _cli_log_initialized = True
     colorful_formatter = Formatter(log_format, time_format)
     colorful_formatter.format = _colorful_format
+    if '_default_' not in handlers:  # this happens when building sphinx gallery
+        _register_handler(StreamHandler(sys.stdout), logging.INFO)
     handlers['_default_'].setFormatter(colorful_formatter)

 def start_experiment_log(experiment_id: str, log_directory: Path, debug: bool) -> None:
...
@@ -23,10 +23,10 @@ def get_next_parameter():
     warning_message = ''.join([
         colorama.Style.BRIGHT,
         colorama.Fore.RED,
-        'Running NNI code without runtime. ',
-        'Check the following tutorial if you are new to NNI: ',
+        'Running trial code without runtime. ',
+        'Please check the tutorial if you are new to NNI: ',
         colorama.Fore.YELLOW,
-        'https://nni.readthedocs.io/en/stable/Tutorial/QuickStart.html#id1',
+        'https://nni.readthedocs.io/en/stable/tutorials/hpo_quickstart_pytorch/main.html',
         colorama.Style.RESET_ALL
     ])
     warnings.warn(warning_message, RuntimeWarning)
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Provide ``nnictl hello`` command to generate quickstart example.
"""

from pathlib import Path
import shutil

from colorama import Fore

import nni_assets

def create_example(_args):
    example_path = Path(nni_assets.__path__[0], 'hello_hpo')
    try:
        shutil.copytree(example_path, 'nni_hello_hpo')
    except PermissionError:
        print(Fore.RED + 'Permission denied. Please run the command in a writable directory.' + Fore.RESET)
        exit(1)
    except FileExistsError:
        print('File exists. Please run "python nni_hello_hpo/main.py" to start the example.')
        exit(1)
    print('A hyperparameter optimization example has been created at "nni_hello_hpo" directory.')
    print('Please run "python nni_hello_hpo/main.py" to try it out.')
@@ -16,7 +16,8 @@ from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment
     save_experiment, load_experiment
 from .algo_management import algo_reg, algo_unreg, algo_show, algo_list
 from .constants import DEFAULT_REST_PORT
-from .import ts_management
+from . import hello
+from . import ts_management

 init(autoreset=True)

@@ -483,6 +484,10 @@ def get_parser():
     jupyter_uninstall_parser = jupyter_subparsers.add_parser('uninstall', description='Uninstall JupyterLab extension.')
     jupyter_uninstall_parser.set_defaults(func=_jupyter_uninstall)

+    # hello command
+    parser_hello = subparsers.add_parser('hello', description='Create "hello nni" example in current directory.')
+    parser_hello.set_defaults(func=hello.create_example)

     return parser
...
@@ -105,7 +105,7 @@ def parse_ids(args):
     3. If there is an id specified, return the corresponding id
     4. If there is no id specified, and there is an experiment running, return the id, or return Error
     5. If the id matches an experiment, nnictl will return the id.
-    6. If the id ends with *, nnictl will match all ids matchs the regular
+    6. If the id ends with ``*``, nnictl will match all ids that match the pattern
     7. If the id does not exist but matches the prefix of an experiment id, nnictl will return the matched id
     8. If the id does not exist but matches multiple prefixes of the experiment ids, nnictl will give id information
     '''
...
@@ -25,16 +25,35 @@ _sequence_id = platform.get_sequence_id()
 def get_next_parameter():
     """
-    Get the hyper paremeters generated by tuner. For a multiphase experiment, it returns a new group of hyper
-    parameters at each call of get_next_parameter. For a non-multiphase (multiPhase is not configured or set to False)
-    experiment, it returns hyper parameters only on the first call for each trial job, it returns None since second call.
-    This API should be called only once in each trial job of an experiment which is not specified as multiphase.
+    Get the hyperparameters generated by tuner.
+
+    Each trial should and should only invoke this function once.
+    Otherwise the behavior is undefined.
+
+    Examples
+    --------
+    Assuming the search space is:
+
+    .. code-block::
+
+        {
+            'activation': {'_type': 'choice', '_value': ['relu', 'tanh', 'sigmoid']},
+            'learning_rate': {'_type': 'loguniform', '_value': [0.0001, 0.1]}
+        }
+
+    Then this function might return:
+
+    .. code-block::
+
+        {
+            'activation': 'relu',
+            'learning_rate': 0.02
+        }

     Returns
     -------
     dict
-        A dict object contains the hyper parameters generated by tuner, the keys of the dict are defined in
-        search space.
+        A hyperparameter set sampled from search space.
+        Returns None if no more hyperparameters can be generated by tuner.
     """
     global _params
     _params = platform.get_next_parameter()

@@ -43,15 +62,6 @@ def get_next_parameter():
         return _params['parameters']

 def get_current_parameter(tag=None):
-    """
-    Get current hyper parameters generated by tuner. It returns the same group of hyper parameters as the last
-    call of get_next_parameter returns.
-
-    Parameters
-    ----------
-    tag: str
-        hyper parameter key
-    """
     global _params
     if _params is None:
         return None

@@ -59,39 +69,25 @@ def get_current_parameter(tag=None):
         return _params['parameters']
     return _params['parameters'][tag]

-def get_experiment_id():
+def get_experiment_id() -> str:
     """
-    Get experiment ID.
-
-    Returns
-    -------
-    str
-        Identifier of current experiment
+    Return experiment ID.
     """
     return _experiment_id

-def get_trial_id():
+def get_trial_id() -> str:
     """
-    Get trial job ID which is string identifier of a trial job, for example 'MoXrp'. In one experiment, each trial
-    job has an unique string ID.
-
-    Returns
-    -------
-    str
-        Identifier of current trial job which is calling this API.
+    Return unique ID of the trial that is currently running.
+
+    This is shown as "ID" in the web portal's trial table.
     """
     return _trial_id

-def get_sequence_id():
+def get_sequence_id() -> int:
     """
-    Get trial job sequence number. A sequence number is an integer value assigned to each trial job based on the
-    order they are submitted, incremental starting from 0. In one experiment, both trial job ID and sequence number
-    are unique for each trial job, they are of different data types.
-
-    Returns
-    -------
-    int
-        Sequence number of current trial job which is calling this API.
+    Return sequence number of the trial that is currently running.
+
+    This is shown as "Trial No." in the web portal's trial table.
     """
     return _sequence_id
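The three identifier APIs side by side; the values in the comments are illustrative only:

```python
import nni

print(nni.get_experiment_id())  # e.g. 'GDbGLbkP' (illustrative)
print(nni.get_trial_id())       # e.g. 'MoXrp', the "ID" column in the web portal
print(nni.get_sequence_id())    # e.g. 7, the "Trial No." column in the web portal
```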
@@ -99,14 +95,6 @@ _intermediate_seq = 0
 def overwrite_intermediate_seq(value):
-    """
-    Overwrite intermediate sequence value.
-
-    Parameters
-    ----------
-    value:
-        int
-    """
     assert isinstance(value, int)
     global _intermediate_seq
     _intermediate_seq = value

@@ -116,10 +104,12 @@ def report_intermediate_result(metric):
     """
     Reports intermediate result to NNI.

-    Parameters
-    ----------
-    metric:
-        serializable object.
+    ``metric`` should either be a float, or a dict where ``metric['default']`` is a float.
+
+    If ``metric`` is a dict, ``metric['default']`` will be used by tuner,
+    and other items can be visualized with web portal.
+
+    Typically ``metric`` is per-epoch accuracy or loss.
     """
     global _intermediate_seq
     assert _params or trial_env_vars.NNI_PLATFORM is None, \

@@ -138,11 +128,12 @@ def report_final_result(metric):
     """
     Reports final result to NNI.

-    Parameters
-    ----------
-    metric: serializable object
-        Usually (for built-in tuners to work), it should be a number, or
-        a dict with key "default" (a number), and any other extra keys.
+    ``metric`` should either be a float, or a dict where ``metric['default']`` is a float.
+
+    If ``metric`` is a dict, ``metric['default']`` will be used by tuner,
+    and other items can be visualized with web portal.
+
+    Typically ``metric`` is the final accuracy or loss.
     """
     assert _params or trial_env_vars.NNI_PLATFORM is None, \
         'nni.get_next_parameter() needs to be called before report_final_result'
...
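A minimal trial sketch tying the reporting APIs to ``get_next_parameter``, using a dict metric so extra keys show up in the web portal:

```python
import nni

params = nni.get_next_parameter()        # must be called before reporting
for epoch in range(3):
    accuracy = 0.5 + 0.1 * epoch         # placeholder for real training
    nni.report_intermediate_result(accuracy)
nni.report_final_result({'default': accuracy, 'epochs': 3})
```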
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import typing

if typing.TYPE_CHECKING or sys.version_info >= (3, 8):
    Literal = typing.Literal
else:
    Literal = typing.Any
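One subtlety of this shim: on Python 3.7 the fallback ``typing.Any`` is not subscriptable, so ``Literal[...]`` can only appear where it is not evaluated at runtime, e.g. in a string annotation. A hypothetical usage sketch:

```python
def set_mode(mode: 'Literal["train", "eval"]') -> None:
    # The string annotation is never evaluated at runtime, so this definition
    # works on 3.7; type checkers still see the real typing.Literal via the shim.
    ...

set_mode('train')
```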
"""
NNI hyperparameter optimization example.
Check the online tutorial for details:
https://nni.readthedocs.io/en/stable/tutorials/hpo_quickstart_pytorch/main.html
"""
from pathlib import Path
import signal
from nni.experiment import Experiment
# Define search space
search_space = {
'features': {'_type': 'choice', '_value': [128, 256, 512, 1024]},
'lr': {'_type': 'loguniform', '_value': [0.0001, 0.1]},
'momentum': {'_type': 'uniform', '_value': [0, 1]},
}
# Configure experiment
experiment = Experiment('local')
experiment.config.trial_command = 'python model.py'
experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.search_space = search_space
experiment.config.tuner.name = 'Random'
experiment.config.max_trial_number = 10
experiment.config.trial_concurrency = 2
# Run it!
experiment.run(port=8080, wait_completion=False)
print('Experiment is running. Press Ctrl-C to quit.')
signal.pause()
"""
Run main.py to start.
This script is modified from PyTorch quickstart:
https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
"""
import nni
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# Get optimized hyperparameters
params = {'features': 512, 'lr': 0.001, 'momentum': 0}
optimized_params = nni.get_next_parameter()
params.update(optimized_params)
# Load dataset
training_data = datasets.FashionMNIST(root='data', train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root='data', train=False, download=True, transform=ToTensor())
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
# Build model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, params['features']),
nn.ReLU(),
nn.Linear(params['features'], params['features']),
nn.ReLU(),
nn.Linear(params['features'], 10)
).to(device)
# Training functions
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=params['momentum'])
def train(dataloader, model, loss_fn, optimizer):
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
def test(dataloader, model, loss_fn):
model.eval()
correct = 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
return correct / len(dataloader.dataset)
# Train the model
epochs = 5
for t in range(epochs):
train(train_dataloader, model, loss_fn, optimizer)
accuracy = test(test_dataloader, model, loss_fn)
nni.report_intermediate_result(accuracy)
nni.report_final_result(accuracy)
@@ -27,7 +27,7 @@ Uninstall:
     $ pip uninstall nni

-Remove generated files: (use "--all" to remove toolchain and built wheel)
+Remove generated files: (use "--all" to remove built wheel)
     $ python setup.py clean [--all]

@@ -112,6 +112,7 @@ def _setup():
         packages = _find_python_packages(),
         package_data = {
             'nni': _find_requirements_txt() + _find_default_config(),  # setuptools issue #1806
+            'nni_assets': _find_asset_files(),
             'nni_node': _find_node_files()  # note: this does not work before building
         },

@@ -124,6 +125,7 @@ def _setup():
             'BOHB': _read_requirements_txt('dependencies/required_extra.txt', 'BOHB'),
             'PPOTuner': _read_requirements_txt('dependencies/required_extra.txt', 'PPOTuner'),
             'DNGO': _read_requirements_txt('dependencies/required_extra.txt', 'DNGO'),
+            'all': _read_requirements_txt('dependencies/required_extra.txt'),
         },
         setup_requires = ['requests'],

@@ -165,6 +167,14 @@ def _find_requirements_txt():
 def _find_default_config():
     return ['runtime/default_config/' + name for name in os.listdir('nni/runtime/default_config')]

+def _find_asset_files():
+    files = []
+    for dirpath, dirnames, filenames in os.walk('nni_assets'):
+        for filename in filenames:
+            if os.path.splitext(filename)[1] == '.py':
+                files.append(os.path.join(dirpath[len('nni_assets/'):], filename))
+    return sorted(files)

 def _find_node_files():
     if not os.path.exists('nni_node'):
         if release and 'build_ts' not in sys.argv and 'clean' not in sys.argv:

@@ -221,7 +231,7 @@ class Build(build):
         check_jupyter_lab_version()
         if os.path.islink('nni_node/main.js'):
-            sys.exit('A development build already exists. Please uninstall NNI and run "python3 setup.py clean --all".')
+            sys.exit('A development build already exists. Please uninstall NNI and run "python3 setup.py clean".')
         open('nni/version.py', 'w').write(f"__version__ = '{release}'")
         super().run()

@@ -259,7 +269,7 @@ class Clean(clean):
     def run(self):
         super().run()
-        setup_ts.clean(self._all)
+        setup_ts.clean()
         _clean_temp_files()
         shutil.rmtree('nni.egg-info', ignore_errors=True)
         if self._all:

@@ -279,7 +289,10 @@ _temp_files = [
     'test/model_path/',
     'test/temp.json',
     'test/ut/sdk/*.pth',
-    'test/ut/tools/annotation/_generated/'
+    'test/ut/tools/annotation/_generated/',
+
+    # example
+    'nni_assets/**/data/',
 ]
...
@@ -15,6 +15,7 @@ from io import BytesIO
 import json
 import os
 from pathlib import Path
+import platform
 import shutil
 import subprocess
 import sys

@@ -56,12 +57,13 @@ def build(release):
     symlink_nni_node()
     restore_package()

-def clean(clean_all=False):
+def clean():
     """
     Remove TypeScript-related intermediate files.
     Python intermediate files are not touched here.
     """
     shutil.rmtree('nni_node', ignore_errors=True)
+    shutil.rmtree('toolchain', ignore_errors=True)
     for file_or_dir in generated_files:
         path = Path(file_or_dir)

@@ -70,13 +72,11 @@ def clean(clean_all=False):
         elif path.is_dir():
             shutil.rmtree(path)

-    if clean_all:
-        shutil.rmtree('toolchain', ignore_errors=True)

 if sys.platform == 'linux' or sys.platform == 'darwin':
     node_executable = 'node'
-    node_spec = f'node-{node_version}-{sys.platform}-x64'
+    _arch = 'x64' if platform.machine() == 'x86_64' else platform.machine()
+    node_spec = f'node-{node_version}-{sys.platform}-' + _arch
     node_download_url = f'https://nodejs.org/dist/{node_version}/{node_spec}.tar.xz'
     node_extractor = lambda data: tarfile.open(fileobj=BytesIO(data), mode='r:xz')
     node_executable_in_tarball = 'bin/node'

@@ -183,16 +183,14 @@ def compile_ts(release):
     _yarn('ts/webui', 'build')

     _print('Building JupyterLab extension')
-    if release:
-        _yarn('ts/jupyter_extension')
-        _yarn('ts/jupyter_extension', 'build')
-    else:
-        try:
-            _yarn('ts/jupyter_extension')
-            _yarn('ts/jupyter_extension', 'build')
-        except Exception:
-            _print('Failed to build JupyterLab extension, skip for develop mode', color='yellow')
-            _print(traceback.format_exc(), color='yellow')
+    try:
+        _yarn('ts/jupyter_extension')
+        _yarn('ts/jupyter_extension', 'build')
+    except Exception:
+        if release:
+            raise
+        _print('Failed to build JupyterLab extension, skip for develop mode', color='yellow')
+        _print(traceback.format_exc(), color='yellow')

 def symlink_nni_node():

@@ -225,12 +223,9 @@ def copy_nni_node(version):
     """
     _print('Copying files')
-    # copytree(..., dirs_exist_ok=True) is not supported by Python 3.6
-    for path in Path('ts/nni_manager/dist').iterdir():
-        if path.is_dir():
-            shutil.copytree(path, Path('nni_node', path.name))
-        elif path.name != 'nni_manager.tsbuildinfo':
-            shutil.copyfile(path, Path('nni_node', path.name))
+    shutil.copytree('ts/nni_manager/dist', 'nni_node', dirs_exist_ok=True)
+    shutil.copyfile('ts/nni_manager/yarn.lock', 'nni_node/yarn.lock')
+    Path('nni_node/nni_manager.tsbuildinfo').unlink()

     package_json = json.load(open('ts/nni_manager/package.json'))
     if version:
...
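The simplification in ``copy_nni_node`` is possible because ``shutil.copytree`` gained ``dirs_exist_ok`` in Python 3.8; the removed loop existed only to support Python 3.6. A one-line sketch with hypothetical paths:

```python
import shutil

# Merge-copy src into a possibly existing dst; on Python < 3.8 the keyword
# argument does not exist and this call raises TypeError.
shutil.copytree('src_dir', 'dst_dir', dirs_exist_ok=True)
```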