Commit 349ead41
Merge branch 'v2.0' into master
Authored Jan 14, 2021 by liuzhe
Parents: 25db55ca, 649ee597
Showing 20 changed files with 369 additions and 234 deletions (+369, -234)
Changed files:
  nni/retiarii/trainer/__init__.py            +1   -1
  nni/retiarii/trainer/pytorch/base.py        +2   -2
  nni/retiarii/trainer/pytorch/darts.py       +14  -3
  nni/retiarii/trainer/pytorch/proxyless.py   +1   -0
  nni/retiarii/trainer/pytorch/utils.py       +3   -2
  nni/retiarii/utils.py                       +75  -48
  nni/runtime/log.py                          +29  -25
  nni/runtime/platform/__init__.py            +1   -1
  nni/tools/nnictl/config_schema.py           +10  -10
  nni/tools/nnictl/launcher.py                +43  -38
  nni/tools/nnictl/launcher_utils.py          +3   -1
  nni/tools/nnictl/nnictl.py                  +6   -6
  nni/tools/nnictl/nnictl_utils.py            +39  -34
  nni/tools/nnictl/tensorboard_utils.py       +7   -8
  nni/tools/nnictl/updater.py                 +4   -3
  pipelines/fast-test.yml                     +28  -21
  pipelines/integration-test-adl.yml          +63  -0
  pipelines/release.yml                       +33  -23
  setup.py                                    +4   -6
  setup_ts.py                                 +3   -2
nni/retiarii/trainer/__init__.py    View file @ 349ead41

-from .interface import BaseTrainer
+from .interface import BaseTrainer, BaseOneShotTrainer
 from .pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
nni/retiarii/trainer/pytorch/base.py    View file @ 349ead41

...
@@ -43,7 +43,7 @@ def get_default_transform(dataset: str) -> Any:
     return None


-@register_trainer()
+@register_trainer
 class PyTorchImageClassificationTrainer(BaseTrainer):
     """
     Image classification trainer for PyTorch.
...
@@ -80,7 +80,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
            Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
            only the key ``max_epochs`` is useful.
        """
-        super(PyTorchImageClassificationTrainer, self).__init__()
+        super().__init__()
        self._use_cuda = torch.cuda.is_available()
        self.model = model
        if self._use_cuda:
...
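Note: the decorator loses its parentheses here because, in the reworked nni/retiarii/utils.py later in this commit, register_trainer is a plain class decorator rather than a decorator factory. A minimal sketch of that difference, with hypothetical names that are not the NNI implementation:

# Decorator factory: must be applied with parentheses, e.g. @register_trainer_factory().
def register_trainer_factory():
    def _register(cls):
        return cls
    return _register

# Plain class decorator: applied without parentheses, as in the new code.
def register_trainer(cls):
    return cls

@register_trainer
class MyTrainer:
    pass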
nni/retiarii/trainer/pytorch/darts.py    View file @ 349ead41

...
@@ -6,6 +6,7 @@ import logging

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ..interface import BaseOneShotTrainer
 from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
...
@@ -17,13 +18,14 @@ _logger = logging.getLogger(__name__)

 class DartsLayerChoice(nn.Module):
     def __init__(self, layer_choice):
         super(DartsLayerChoice, self).__init__()
+        self.name = layer_choice.key
         self.op_choices = nn.ModuleDict(layer_choice.named_children())
         self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

     def forward(self, *args, **kwargs):
         op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
         alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
-        return torch.sum(op_results * self.alpha.view(*alpha_shape), 0)
+        return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

     def parameters(self):
         for _, p in self.named_parameters():
...
@@ -42,13 +44,14 @@ class DartsLayerChoice(nn.Module):

 class DartsInputChoice(nn.Module):
     def __init__(self, input_choice):
         super(DartsInputChoice, self).__init__()
+        self.name = input_choice.key
         self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
         self.n_chosen = input_choice.n_chosen or 1

     def forward(self, inputs):
         inputs = torch.stack(inputs)
         alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
-        return torch.sum(inputs * self.alpha.view(*alpha_shape), 0)
+        return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

     def parameters(self):
         for _, p in self.named_parameters():
...
@@ -123,7 +126,15 @@ class DartsTrainer(BaseOneShotTrainer):
             module.to(self.device)

         self.model_optim = optimizer
-        self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate, betas=(0.5, 0.999),
+        # use the same architecture weight for modules with duplicated names
+        ctrl_params = {}
+        for _, m in self.nas_modules:
+            if m.name in ctrl_params:
+                assert m.alpha.size() == ctrl_params[m.name].size(), 'Size of parameters with the same label should be same.'
+                m.alpha = ctrl_params[m.name]
+            else:
+                ctrl_params[m.name] = m.alpha
+        self.ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), arc_learning_rate, betas=(0.5, 0.999),
                                            weight_decay=1.0E-3)
         self.unrolled = unrolled
         self.grad_clip = 5.
...
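Two behavioral changes are visible in this file: the architecture weights alpha are now passed through softmax before mixing candidate outputs, and choices that share the same key reuse a single alpha so the controller optimizer sees each label once. A minimal sketch of the mixing step, assuming candidate outputs of identical shape (illustrative only, not the NNI module):

import torch
import torch.nn.functional as F

def mix_candidates(op_results: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # op_results: (n_candidates, batch, ...), alpha: (n_candidates,)
    weights = F.softmax(alpha, -1)                      # normalize architecture weights
    alpha_shape = [-1] + [1] * (op_results.dim() - 1)   # broadcast over trailing dims
    return torch.sum(op_results * weights.view(*alpha_shape), 0)

ops = torch.stack([torch.randn(4, 8) for _ in range(3)])   # 3 candidate outputs
alpha = torch.randn(3) * 1e-3
mixed = mix_candidates(ops, alpha)                          # shape (4, 8)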
nni/retiarii/trainer/pytorch/proxyless.py    View file @ 349ead41

...
@@ -157,6 +157,7 @@ class ProxylessTrainer(BaseOneShotTrainer):
             module.to(self.device)

         self.optimizer = optimizer
+        # we do not support deduplicate control parameters with same label (like DARTS) yet.
         self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate,
                                            weight_decay=0, betas=(0, 0.999), eps=1e-8)
         self._init_dataloader()
...
nni/retiarii/trainer/pytorch/utils.py    View file @ 349ead41

...
@@ -6,6 +6,7 @@ from collections import OrderedDict

 import numpy as np
 import torch
+import nni.retiarii.nn.pytorch as nn
 from nni.nas.pytorch.mutables import InputChoice, LayerChoice

 _logger = logging.getLogger(__name__)
...
@@ -157,7 +158,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
     List[Tuple[str, nn.Module]]
         A list from layer choice keys (names) and replaced modules.
     """
-    return _replace_module_with_type(root_module, init_fn, LayerChoice, modules)
+    return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)


 def replace_input_choice(root_module, init_fn, modules=None):
...
@@ -178,4 +179,4 @@ def replace_input_choice(root_module, init_fn, modules=None):
     List[Tuple[str, nn.Module]]
         A list from layer choice keys (names) and replaced modules.
     """
-    return _replace_module_with_type(root_module, init_fn, InputChoice, modules)
+    return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
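The replacement helpers now match both the legacy mutables and the new nni.retiarii.nn.pytorch choice classes by passing a tuple of types. The real _replace_module_with_type is not shown in this diff, so the following is only an assumption about its general shape, illustrating why a tuple works (isinstance accepts either one class or a tuple of classes):

import torch.nn as nn

def _replace_module_with_type(root_module, init_fn, type_or_types, modules=None):
    # type_or_types may be a single class or a tuple such as (LayerChoice, nn.LayerChoice);
    # isinstance matches instances of any class in the tuple.
    if modules is None:
        modules = []
    for name, child in root_module.named_children():
        if isinstance(child, type_or_types):
            setattr(root_module, name, init_fn(child))
            modules.append((name, child))
        else:
            _replace_module_with_type(child, init_fn, type_or_types, modules)
    return modules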
nni/retiarii/utils.py    View file @ 349ead41

 import inspect
+import warnings
 from collections import defaultdict
 from typing import Any
...
@@ -11,6 +12,13 @@ def import_(target: str, allow_none: bool = False) -> Any:
     return getattr(module, identifier)


+def version_larger_equal(a: str, b: str) -> bool:
+    # TODO: refactor later
+    a = a.split('+')[0]
+    b = b.split('+')[0]
+    return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))
+
+
 _records = {}
...
@@ -19,6 +27,11 @@ def get_records():
     return _records


+def clear_records():
+    global _records
+    _records = {}
+
+
 def add_record(key, value):
     """
     """
...
@@ -28,69 +41,83 @@ def add_record(key, value):
     _records[key] = value


-def _register_module(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            full_args[argname_list[i]] = arg
-        add_record(id(self), full_args)
-
-        orig_init(self, *args, **kws)  # Call the original __init__
-
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-
-def register_module():
-    """
-    Register a module.
-    """
-    # use it as a decorator: @register_module()
-    def _register(cls):
-        m = _register_module(original_class=cls)
-        return m
-    return _register
-
-
-def _register_trainer(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-    full_class_name = original_class.__module__ + '.' + original_class.__name__
-
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            # TODO: support both pytorch and tensorflow
-            from .nn.pytorch import Module
-            if isinstance(args[i], Module):
-                # ignore the base model object
-                continue
-            full_args[argname_list[i]] = arg
-        add_record(id(self), {'modulename': full_class_name, 'args': full_args})
-
-        orig_init(self, *args, **kws)  # Call the original __init__
-
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-
-def register_trainer():
-    def _register(cls):
-        m = _register_trainer(original_class=cls)
-        return m
-    return _register
+def del_record(key):
+    global _records
+    if _records is not None:
+        _records.pop(key, None)
+
+
+def _blackbox_cls(cls, module_name, register_format=None):
+    class wrapper(cls):
+        def __init__(self, *args, **kwargs):
+            argname_list = list(inspect.signature(cls).parameters.keys())
+            full_args = {}
+            full_args.update(kwargs)
+
+            assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
+            for argname, value in zip(argname_list, args):
+                full_args[argname] = value
+
+            # eject un-serializable arguments
+            for k in list(full_args.keys()):
+                # The list is not complete and does not support nested cases.
+                if not isinstance(full_args[k], (int, float, str, dict, list, tuple)):
+                    if not (register_format == 'full' and k == 'model'):
+                        # no warning if it is base model in trainer
+                        warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
+This is not supported. You can ignore this warning if you are passing the model to trainer.')
+                    full_args.pop(k)
+
+            if register_format == 'args':
+                add_record(id(self), full_args)
+            elif register_format == 'full':
+                full_class_name = cls.__module__ + '.' + cls.__name__
+                add_record(id(self), {'modulename': full_class_name, 'args': full_args})
+
+            super().__init__(*args, **kwargs)
+
+        def __del__(self):
+            del_record(id(self))
+
+    # using module_name instead of cls.__module__ because it's more natural to see where the module gets wrapped
+    # instead of simply putting torch.nn or etc.
+    wrapper.__module__ = module_name
+    wrapper.__name__ = cls.__name__
+    wrapper.__qualname__ = cls.__qualname__
+    wrapper.__init__.__doc__ = cls.__init__.__doc__
+
+    return wrapper
+
+
+def blackbox(cls, *args, **kwargs):
+    """
+    To create an blackbox instance inline without decorator. For example,
+
+    .. code-block:: python
+        self.op = blackbox(MyCustomOp, hidden_units=128)
+    """
+    # get caller module name
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)
+
+
+def blackbox_module(cls):
+    """
+    Register a module. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')
+
+
+def register_trainer(cls):
+    """
+    Register a trainer. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'full')


 _last_uid = defaultdict(int)
...
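The old helpers patched __init__ on the original class; the new _blackbox_cls instead returns a subclass whose __init__ records the serializable constructor arguments before delegating to the real one, which is what lets Retiarii reconstruct user modules later. A stripped-down sketch of the same wrapping idea, much simpler than the code above and purely illustrative:

import inspect

_records = {}

def record_init_args(cls):
    """Decorator: remember the (serializable) constructor args of every instance."""
    class Wrapper(cls):
        def __init__(self, *args, **kwargs):
            names = list(inspect.signature(cls).parameters.keys())
            full_args = dict(zip(names, args), **kwargs)
            _records[id(self)] = {k: v for k, v in full_args.items()
                                  if isinstance(v, (int, float, str, list, tuple, dict))}
            super().__init__(*args, **kwargs)
    Wrapper.__name__ = cls.__name__
    return Wrapper

@record_init_args
class Conv:
    def __init__(self, channels, kernel_size=3):
        self.channels, self.kernel_size = channels, kernel_size

c = Conv(16, kernel_size=5)
print(_records[id(c)])   # {'channels': 16, 'kernel_size': 5}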
nni/runtime/log.py    View file @ 349ead41

...
@@ -12,6 +12,13 @@ import colorama

 from .env_vars import dispatcher_env_vars, trial_env_vars

+handlers = {}
+
+log_format = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
+time_format = '%Y-%m-%d %H:%M:%S'
+formatter = Formatter(log_format, time_format)
+
+
 def init_logger() -> None:
     """
     This function will (and should only) get invoked on the first time of importing nni (no matter which submodule).
...
@@ -37,6 +44,8 @@ def init_logger() -> None:
         _init_logger_standalone()

+    logging.getLogger('filelock').setLevel(logging.WARNING)
+

 def init_logger_experiment() -> None:
     """
...
@@ -44,15 +53,19 @@ def init_logger_experiment() -> None:
     This function will get invoked after `init_logger()`.
     """
-    formatter.format = _colorful_format
+    colorful_formatter = Formatter(log_format, time_format)
+    colorful_formatter.format = _colorful_format
+    handlers['_default_'].setFormatter(colorful_formatter)

-time_format = '%Y-%m-%d %H:%M:%S'
+def start_experiment_log(experiment_id: str, log_directory: Path, debug: bool) -> None:
+    log_path = _prepare_log_dir(log_directory) / 'dispatcher.log'
+    log_level = logging.DEBUG if debug else logging.INFO
+    _register_handler(FileHandler(log_path), log_level, experiment_id)

-formatter = Formatter('[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s', time_format)
+def stop_experiment_log(experiment_id: str) -> None:
+    if experiment_id in handlers:
+        logging.getLogger().removeHandler(handlers.pop(experiment_id))


 def _init_logger_dispatcher() -> None:
     log_level_map = {
...
@@ -66,26 +79,20 @@ def _init_logger_dispatcher() -> None:
     log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
     log_level = log_level_map.get(dispatcher_env_vars.NNI_LOG_LEVEL, logging.INFO)
-    _setup_root_logger(FileHandler(log_path), log_level)
+    _register_handler(FileHandler(log_path), log_level)


 def _init_logger_trial() -> None:
     log_path = _prepare_log_dir(trial_env_vars.NNI_OUTPUT_DIR) / 'trial.log'
     log_file = open(log_path, 'w')
-    _setup_root_logger(StreamHandler(log_file), logging.INFO)
+    _register_handler(StreamHandler(log_file), logging.INFO)

     if trial_env_vars.NNI_PLATFORM == 'local':
         sys.stdout = _LogFileWrapper(log_file)


 def _init_logger_standalone() -> None:
-    _setup_nni_logger(StreamHandler(sys.stdout), logging.INFO)
-
-    # Following line does not affect NNI loggers, but without this user's logger won't
-    # print log even it's level is set to INFO, so we do it for user's convenience.
-    # If this causes any issue in future, remove it and use `logging.info()` instead of
-    # `logging.getLogger('xxx').info()` in all examples.
-    logging.basicConfig()
+    _register_handler(StreamHandler(sys.stdout), logging.INFO)


 def _prepare_log_dir(path: Optional[str]) -> Path:
...
@@ -95,20 +102,18 @@ def _prepare_log_dir(path: Optional[str]) -> Path:
     ret.mkdir(parents=True, exist_ok=True)
     return ret


-def _setup_root_logger(handler: Handler, level: int) -> None:
-    _setup_logger('', handler, level)
-
-
-def _setup_nni_logger(handler: Handler, level: int) -> None:
-    _setup_logger('nni', handler, level)
-
-
-def _setup_logger(name: str, handler: Handler, level: int) -> None:
+def _register_handler(handler: Handler, level: int, tag: str = '_default_') -> None:
+    assert tag not in handlers
+    handlers[tag] = handler
     handler.setFormatter(formatter)
-    logger = logging.getLogger(name)
+    logger = logging.getLogger()
     logger.addHandler(handler)
     logger.setLevel(level)
-    logger.propagate = False


 def _colorful_format(record):
+    time = formatter.formatTime(record, time_format)
+    if not record.name.startswith('nni.'):
+        return '[{}] ({}) {}'.format(time, record.name, record.msg % record.args)
     if record.levelno >= logging.ERROR:
         color = colorama.Fore.RED
     elif record.levelno >= logging.WARNING:
...
@@ -118,7 +123,6 @@ def _colorful_format(record):
     else:
         color = colorama.Fore.BLUE
     msg = color + (record.msg % record.args) + colorama.Style.RESET_ALL
-    time = formatter.formatTime(record, time_format)
     if record.levelno < logging.INFO:
         return '[{}] {}:{} {}'.format(time, record.threadName, record.name, msg)
     else:
...
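Logging now keeps every handler in a module-level handlers dict keyed by a tag ('_default_' or an experiment id), so start_experiment_log and stop_experiment_log can attach and later detach a per-experiment file handler without disturbing the others. A minimal sketch of that registry pattern, illustrative rather than NNI's actual module:

import logging
import sys
from logging import Formatter, StreamHandler, FileHandler, Handler

handlers = {}
formatter = Formatter('[%(asctime)s] %(levelname)s %(message)s')

def register_handler(handler: Handler, level: int, tag: str = '_default_') -> None:
    assert tag not in handlers
    handlers[tag] = handler
    handler.setFormatter(formatter)
    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(level)

def unregister_handler(tag: str) -> None:
    if tag in handlers:
        logging.getLogger().removeHandler(handlers.pop(tag))

register_handler(StreamHandler(sys.stdout), logging.INFO)
register_handler(FileHandler('experiment_abc.log'), logging.DEBUG, tag='abc')
logging.getLogger('demo').info('goes to stdout and to experiment_abc.log')
unregister_handler('abc')   # experiment stopped; its file handler is removed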
nni/runtime/platform/__init__.py    View file @ 349ead41

...
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
     from .standalone import *
 elif trial_env_vars.NNI_PLATFORM == 'unittest':
     from .test import *
-elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'):
+elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'hybrid'):
     from .local import *
 else:
     raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
nni/tools/nnictl/config_schema.py    View file @ 349ead41

...
@@ -124,7 +124,7 @@ common_schema = {
     Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
     Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
     'trainingServicePlatform': setChoice(
-        'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'),
+        'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'hybrid'),
     Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
     Optional('multiPhase'): setType('multiPhase', bool),
     Optional('multiThread'): setType('multiThread', bool),
...
@@ -262,7 +262,7 @@ aml_config_schema = {
     }
 }

-heterogeneous_trial_schema = {
+hybrid_trial_schema = {
     'trial': {
         'codeDir': setPathCheck('codeDir'),
         Optional('nniManagerNFSMountPath'): setPathCheck('nniManagerNFSMountPath'),
...
@@ -279,8 +279,8 @@ heterogeneous_trial_schema = {
     }
 }

-heterogeneous_config_schema = {
-    'heterogeneousConfig': {
+hybrid_config_schema = {
+    'hybridConfig': {
         'trainingServicePlatforms': ['local', 'remote', 'pai', 'aml']
     }
 }
...
@@ -461,7 +461,7 @@ training_service_schema_dict = {
     'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}),
     'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}),
     'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}),
-    'heterogeneous': Schema({**common_schema, **heterogeneous_trial_schema, **heterogeneous_config_schema, **machine_list_schema,
-                             **pai_config_schema, **aml_config_schema, **remote_config_schema}),
+    'hybrid': Schema({**common_schema, **hybrid_trial_schema, **hybrid_config_schema, **machine_list_schema,
+                      **pai_config_schema, **aml_config_schema, **remote_config_schema}),
 }
...
@@ -479,7 +479,7 @@ class NNIConfigSchema:
         self.validate_pai_trial_conifg(experiment_config)
         self.validate_kubeflow_operators(experiment_config)
         self.validate_eth0_device(experiment_config)
-        self.validate_heterogeneous_platforms(experiment_config)
+        self.validate_hybrid_platforms(experiment_config)

     def validate_tuner_adivosr_assessor(self, experiment_config):
         if experiment_config.get('advisor'):
...
@@ -590,15 +590,15 @@ class NNIConfigSchema:
                 and 'eth0' not in netifaces.interfaces():
             raise SchemaError('This machine does not contain eth0 network device, please set nniManagerIp in config file!')

-    def validate_heterogeneous_platforms(self, experiment_config):
+    def validate_hybrid_platforms(self, experiment_config):
         required_config_name_map = {
             'remote': 'machineList',
             'aml': 'amlConfig',
             'pai': 'paiConfig'
         }
-        if experiment_config.get('trainingServicePlatform') == 'heterogeneous':
-            for platform in experiment_config['heterogeneousConfig']['trainingServicePlatforms']:
+        if experiment_config.get('trainingServicePlatform') == 'hybrid':
+            for platform in experiment_config['hybridConfig']['trainingServicePlatforms']:
                 config_name = required_config_name_map.get(platform)
                 if config_name and not experiment_config.get(config_name):
-                    raise SchemaError('Need to set {0} for {1} in heterogeneous mode!'.format(config_name, platform))
+                    raise SchemaError('Need to set {0} for {1} in hybrid mode!'.format(config_name, platform))
\ No newline at end of file
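With the rename, a hybrid experiment declares its sub-platforms under hybridConfig, and validate_hybrid_platforms insists that the matching per-platform section (machineList, amlConfig, paiConfig) is present. A sketch of the same check against an illustrative config dict; the field values below are made up:

REQUIRED = {'remote': 'machineList', 'aml': 'amlConfig', 'pai': 'paiConfig'}

def check_hybrid(experiment_config):
    if experiment_config.get('trainingServicePlatform') != 'hybrid':
        return
    for platform in experiment_config['hybridConfig']['trainingServicePlatforms']:
        needed = REQUIRED.get(platform)
        if needed and not experiment_config.get(needed):
            raise ValueError('Need to set {0} for {1} in hybrid mode!'.format(needed, platform))

config = {
    'trainingServicePlatform': 'hybrid',
    'hybridConfig': {'trainingServicePlatforms': ['local', 'remote']},
    'machineList': [{'ip': '10.0.0.1', 'username': 'nni', 'port': 22}],
}
check_hybrid(config)   # passes; dropping machineList would raise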
nni/tools/nnictl/launcher.py    View file @ 349ead41

...
@@ -17,7 +17,7 @@ from .launcher_utils import validate_all_content
 from .rest_utils import rest_put, rest_post, check_rest_server, check_response
 from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
 from .config_utils import Config, Experiments
-from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \
+from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
     detect_port, get_user
 from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
...
@@ -47,10 +47,10 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
                     'You could use \'nnictl create --help\' to get help information' % port)
         exit(1)

-    if (platform != 'local') and detect_port(int(port) + 1):
-        print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
+    if (platform not in ['local', 'aml']) and detect_port(int(port) + 1):
+        print_error('%s mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
                     'You could set another port to start experiment!\n' \
-                    'You could use \'nnictl create --help\' to get help information' % ((int(port) + 1), (int(port) + 1)))
+                    'You could use \'nnictl create --help\' to get help information' % (platform, (int(port) + 1), (int(port) + 1)))
         exit(1)

     print_normal('Starting restful server...')
...
@@ -300,23 +300,25 @@ def set_aml_config(experiment_config, port, config_file_name):
     #set trial_config
     return set_trial_config(experiment_config, port, config_file_name), err_message

-def set_heterogeneous_config(experiment_config, port, config_file_name):
-    '''set heterogeneous configuration'''
-    heterogeneous_config_data = dict()
-    heterogeneous_config_data['heterogeneous_config'] = experiment_config['heterogeneousConfig']
-    platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms']
+def set_hybrid_config(experiment_config, port, config_file_name):
+    '''set hybrid configuration'''
+    hybrid_config_data = dict()
+    hybrid_config_data['hybrid_config'] = experiment_config['hybridConfig']
+    platform_list = experiment_config['hybridConfig']['trainingServicePlatforms']
     for platform in platform_list:
         if platform == 'aml':
-            heterogeneous_config_data['aml_config'] = experiment_config['amlConfig']
+            hybrid_config_data['aml_config'] = experiment_config['amlConfig']
         elif platform == 'remote':
             if experiment_config.get('remoteConfig'):
-                heterogeneous_config_data['remote_config'] = experiment_config['remoteConfig']
-            heterogeneous_config_data['machine_list'] = experiment_config['machineList']
+                hybrid_config_data['remote_config'] = experiment_config['remoteConfig']
+            hybrid_config_data['machine_list'] = experiment_config['machineList']
         elif platform == 'local' and experiment_config.get('localConfig'):
-            heterogeneous_config_data['local_config'] = experiment_config['localConfig']
+            hybrid_config_data['local_config'] = experiment_config['localConfig']
         elif platform == 'pai':
-            heterogeneous_config_data['pai_config'] = experiment_config['paiConfig']
-    response = rest_put(cluster_metadata_url(port), json.dumps(heterogeneous_config_data), REST_TIME_OUT)
+            hybrid_config_data['pai_config'] = experiment_config['paiConfig']
+    # It needs to connect all remote machines, set longer timeout here to wait for restful server connection response.
+    time_out = 60 if 'remote' in platform_list else REST_TIME_OUT
+    response = rest_put(cluster_metadata_url(port), json.dumps(hybrid_config_data), time_out)
     err_message = None
     if not response or not response.status_code == 200:
         if response is not None:
...
@@ -412,10 +414,10 @@ def set_experiment(experiment_config, mode, port, config_file_name):
             {'key': 'aml_config', 'value': experiment_config['amlConfig']})
         request_data['clusterMetaData'].append(
             {'key': 'trial_config', 'value': experiment_config['trial']})
-    elif experiment_config['trainingServicePlatform'] == 'heterogeneous':
+    elif experiment_config['trainingServicePlatform'] == 'hybrid':
         request_data['clusterMetaData'].append(
-            {'key': 'heterogeneous_config', 'value': experiment_config['heterogeneousConfig']})
-        platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms']
+            {'key': 'hybrid_config', 'value': experiment_config['hybridConfig']})
+        platform_list = experiment_config['hybridConfig']['trainingServicePlatforms']
         request_dict = {
             'aml': {'key': 'aml_config', 'value': experiment_config.get('amlConfig')},
             'remote': {'key': 'machine_list', 'value': experiment_config.get('machineList')},
...
@@ -460,8 +462,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
         config_result, err_msg = set_dlts_config(experiment_config, port, config_file_name)
     elif platform == 'aml':
         config_result, err_msg = set_aml_config(experiment_config, port, config_file_name)
-    elif platform == 'heterogeneous':
-        config_result, err_msg = set_heterogeneous_config(experiment_config, port, config_file_name)
+    elif platform == 'hybrid':
+        config_result, err_msg = set_hybrid_config(experiment_config, port, config_file_name)
     else:
         raise Exception(ERROR_INFO % 'Unsupported platform!')
         exit(1)
...
@@ -509,6 +511,11 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
     rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
                                                  mode, experiment_id, foreground, log_dir, log_level)
     nni_config.set_config('restServerPid', rest_process.pid)
+    # save experiment information
+    nnictl_experiment_config = Experiments()
+    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
+                                            experiment_config['trainingServicePlatform'],
+                                            experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
     # Deal with annotation
     if experiment_config.get('useAnnotation'):
         path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
...
@@ -546,11 +553,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
     # start a new experiment
     print_normal('Starting experiment...')
-    # save experiment information
-    nnictl_experiment_config = Experiments()
-    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
-                                            experiment_config['trainingServicePlatform'],
-                                            experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
     # set debug configuration
     if mode != 'view' and experiment_config.get('debug') is None:
         experiment_config['debug'] = args.debug
...
@@ -567,7 +569,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
         raise Exception(ERROR_INFO % 'Restful server stopped!')
         exit(1)
     if experiment_config.get('nniManagerIp'):
-        web_ui_url_list = ['{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
+        web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
     else:
         web_ui_url_list = get_local_urls(args.port)
     nni_config.set_config('webuiUrl', web_ui_url_list)
...
@@ -592,24 +594,28 @@ def create_experiment(args):
         print_error('Please set correct config path!')
         exit(1)
     experiment_config = get_yml_content(config_path)
-    try:
-        config = ExperimentConfig(**experiment_config)
-        experiment_config = convert.to_v1_yaml(config)
-    except Exception:
-        pass
     try:
         validate_all_content(experiment_config, config_path)
-    except Exception as e:
-        print_error(e)
-        exit(1)
+    except Exception:
+        print_warning('Validation with V1 schema failed. Trying to convert from V2 format...')
+        try:
+            config = ExperimentConfig(**experiment_config)
+            experiment_config = convert.to_v1_yaml(config)
+        except Exception as e:
+            print_error(f'Conversion from v2 format failed: {repr(e)}')
+        try:
+            validate_all_content(experiment_config, config_path)
+        except Exception as e:
+            print_error(f'Config in v1 format validation failed. {repr(e)}')
+            exit(1)

     nni_config.set_config('experimentConfig', experiment_config)
     nni_config.set_config('restServerPort', args.port)
     try:
         launch_experiment(args, experiment_config, 'new', experiment_id)
     except Exception as exception:
-        nni_config = Config(experiment_id)
-        restServerPid = nni_config.get_config('restServerPid')
+        restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
         if restServerPid:
             kill_command(restServerPid)
         print_error(exception)
...
@@ -641,8 +647,7 @@ def manage_stopped_experiment(args, mode):
     try:
         launch_experiment(args, experiment_config, mode, experiment_id)
     except Exception as exception:
-        nni_config = Config(experiment_id)
-        restServerPid = nni_config.get_config('restServerPid')
+        restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
         if restServerPid:
             kill_command(restServerPid)
         print_error(exception)
...
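create_experiment now tries the legacy V1 schema first and only falls back to treating the YAML as a V2 ExperimentConfig (converting it back to V1 with convert.to_v1_yaml) when that validation fails. A compressed sketch of that control flow, with validate_v1 and convert_v2_to_v1 standing in for the real helpers:

def load_config(raw_config: dict, validate_v1, convert_v2_to_v1) -> dict:
    """Return a V1-style config dict, accepting either V1 or V2 input."""
    try:
        validate_v1(raw_config)                      # happy path: already valid V1
        return raw_config
    except Exception:
        print('V1 validation failed, trying V2 conversion...')
    try:
        raw_config = convert_v2_to_v1(raw_config)    # may raise if it is not V2 either
    except Exception as e:
        print(f'Conversion from v2 format failed: {e!r}')
    validate_v1(raw_config)                          # final check; raises if still invalid
    return raw_config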
nni/tools/nnictl/launcher_utils.py    View file @ 349ead41

...
@@ -105,7 +105,9 @@ def set_default_values(experiment_config):
         experiment_config['maxExecDuration'] = '999d'
     if experiment_config.get('maxTrialNum') is None:
         experiment_config['maxTrialNum'] = 99999
-    if experiment_config['trainingServicePlatform'] == 'remote':
+    if experiment_config['trainingServicePlatform'] == 'remote' or \
+        experiment_config['trainingServicePlatform'] == 'hybrid' and \
+        'remote' in experiment_config['hybridConfig']['trainingServicePlatforms']:
         for index in range(len(experiment_config['machineList'])):
             if experiment_config['machineList'][index].get('port') is None:
                 experiment_config['machineList'][index]['port'] = 22
...
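The default SSH port is now also filled in when 'remote' appears inside a hybrid experiment's trainingServicePlatforms, not just when the top-level platform is 'remote'. The chained condition above relies on and binding tighter than or; the sketch below expresses the same rule with explicit parentheses (config shape assumed from the diff):

def fill_default_ssh_ports(experiment_config: dict) -> None:
    platform = experiment_config.get('trainingServicePlatform')
    uses_remote = platform == 'remote' or (
        platform == 'hybrid'
        and 'remote' in experiment_config['hybridConfig']['trainingServicePlatforms'])
    if uses_remote:
        for machine in experiment_config.get('machineList', []):
            machine.setdefault('port', 22)   # default SSH port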
nni/tools/nnictl/nnictl.py    View file @ 349ead41

...
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 import argparse
+import logging
 import os
 import pkg_resources
 from colorama import init
...
@@ -32,6 +33,8 @@ def nni_info(*args):
         print('please run "nnictl {positional argument} --help" to see nnictl guidance')

 def parse_args():
+    logging.getLogger().setLevel(logging.ERROR)
+
     '''Definite the arguments users need to follow and input'''
     parser = argparse.ArgumentParser(prog='nnictl', description='use nnictl command to control nni experiments')
     parser.add_argument('--version', '-v', action='store_true')
...
@@ -243,12 +246,9 @@ def parse_args():
     def show_messsage_for_nnictl_package(args):
         print_error('nnictl package command is replaced by nnictl algo, please run nnictl algo -h to show the usage')

-    parser_package_subparsers = subparsers.add_parser('package', help='control nni tuner and assessor packages').add_subparsers()
-    parser_package_subparsers.add_parser('install', help='install packages').set_defaults(func=show_messsage_for_nnictl_package)
-    parser_package_subparsers.add_parser('uninstall', help='uninstall packages').set_defaults(func=show_messsage_for_nnictl_package)
-    parser_package_subparsers.add_parser('show', help='show the information of packages').set_defaults(func=show_messsage_for_nnictl_package)
-    parser_package_subparsers.add_parser('list', help='list installed packages').set_defaults(func=show_messsage_for_nnictl_package)
+    parser_package_subparsers = subparsers.add_parser('package', help='this argument is replaced by algo', prefix_chars='\n')
+    parser_package_subparsers.add_argument('args', nargs=argparse.REMAINDER)
+    parser_package_subparsers.set_defaults(func=show_messsage_for_nnictl_package)

     #parse tensorboard command
     parser_tensorboard = subparsers.add_parser('tensorboard', help='manage tensorboard')
...
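Instead of registering real install/uninstall/show/list sub-subcommands, the deprecated `nnictl package` parser now swallows whatever follows it: prefix_chars='\n' keeps argparse from treating any token as an option flag, and nargs=argparse.REMAINDER collects the rest, so every invocation lands in the message handler. A sketch of the same trick for a hypothetical deprecated subcommand of a made-up tool:

import argparse

def warn_deprecated(args):
    print('the "legacy" command has been replaced; run "tool new --help"')

parser = argparse.ArgumentParser(prog='tool')
subparsers = parser.add_subparsers()
# prefix_chars='\n' means no token looks like an option flag inside this sub-parser.
legacy = subparsers.add_parser('legacy', help='replaced by "new"', prefix_chars='\n')
legacy.add_argument('args', nargs=argparse.REMAINDER)   # capture everything after "legacy"
legacy.set_defaults(func=warn_deprecated)

ns = parser.parse_args(['legacy', 'install', '--something'])
ns.func(ns)   # prints the deprecation message regardless of trailing tokens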
nni/tools/nnictl/nnictl_utils.py
View file @
349ead41
...
@@ -50,11 +50,9 @@ def update_experiment():
...
@@ -50,11 +50,9 @@ def update_experiment():
for
key
in
experiment_dict
.
keys
():
for
key
in
experiment_dict
.
keys
():
if
isinstance
(
experiment_dict
[
key
],
dict
):
if
isinstance
(
experiment_dict
[
key
],
dict
):
if
experiment_dict
[
key
].
get
(
'status'
)
!=
'STOPPED'
:
if
experiment_dict
[
key
].
get
(
'status'
)
!=
'STOPPED'
:
nni_config
=
Config
(
key
)
rest_pid
=
experiment_dict
[
key
].
get
(
'pid'
)
rest_pid
=
nni_config
.
get_config
(
'restServerPid'
)
if
not
detect_process
(
rest_pid
):
if
not
detect_process
(
rest_pid
):
experiment_config
.
update_experiment
(
key
,
'status'
,
'STOPPED'
)
experiment_config
.
update_experiment
(
key
,
'status'
,
'STOPPED'
)
experiment_config
.
update_experiment
(
key
,
'port'
,
None
)
continue
continue
def
check_experiment_id
(
args
,
update
=
True
):
def
check_experiment_id
(
args
,
update
=
True
):
...
@@ -83,10 +81,10 @@ def check_experiment_id(args, update=True):
...
@@ -83,10 +81,10 @@ def check_experiment_id(args, update=True):
experiment_information
+=
EXPERIMENT_DETAIL_FORMAT
%
(
key
,
experiment_information
+=
EXPERIMENT_DETAIL_FORMAT
%
(
key
,
experiment_dict
[
key
].
get
(
'experimentName'
,
'N/A'
),
experiment_dict
[
key
].
get
(
'experimentName'
,
'N/A'
),
experiment_dict
[
key
][
'status'
],
experiment_dict
[
key
][
'status'
],
experiment_dict
[
key
]
[
'port'
]
,
experiment_dict
[
key
]
.
get
(
'port'
,
'N/A'
)
,
experiment_dict
[
key
].
get
(
'platform'
),
experiment_dict
[
key
].
get
(
'platform'
),
experiment_dict
[
key
][
'startTime'
],
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
experiment_dict
[
key
][
'startTime'
]
/
1000
))
if
isinstance
(
experiment_dict
[
key
][
'startTime'
],
int
)
else
experiment_dict
[
key
][
'startTime'
],
experiment_dict
[
key
][
'endTime'
])
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
experiment_dict
[
key
][
'endTime'
]
/
1000
))
if
isinstance
(
experiment_dict
[
key
][
'endTime'
],
int
)
else
experiment_dict
[
key
][
'endTime'
])
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
exit
(
1
)
exit
(
1
)
elif
not
running_experiment_list
:
elif
not
running_experiment_list
:
...
@@ -130,7 +128,7 @@ def parse_ids(args):
...
@@ -130,7 +128,7 @@ def parse_ids(args):
return
running_experiment_list
return
running_experiment_list
if
args
.
port
is
not
None
:
if
args
.
port
is
not
None
:
for
key
in
running_experiment_list
:
for
key
in
running_experiment_list
:
if
experiment_dict
[
key
]
[
'port'
]
==
args
.
port
:
if
experiment_dict
[
key
]
.
get
(
'port'
)
==
args
.
port
:
result_list
.
append
(
key
)
result_list
.
append
(
key
)
if
args
.
id
and
result_list
and
args
.
id
!=
result_list
[
0
]:
if
args
.
id
and
result_list
and
args
.
id
!=
result_list
[
0
]:
print_error
(
'Experiment id and resful server port not match'
)
print_error
(
'Experiment id and resful server port not match'
)
...
@@ -143,10 +141,10 @@ def parse_ids(args):
...
@@ -143,10 +141,10 @@ def parse_ids(args):
experiment_information
+=
EXPERIMENT_DETAIL_FORMAT
%
(
key
,
experiment_information
+=
EXPERIMENT_DETAIL_FORMAT
%
(
key
,
experiment_dict
[
key
].
get
(
'experimentName'
,
'N/A'
),
experiment_dict
[
key
].
get
(
'experimentName'
,
'N/A'
),
experiment_dict
[
key
][
'status'
],
experiment_dict
[
key
][
'status'
],
experiment_dict
[
key
]
[
'port'
]
,
experiment_dict
[
key
]
.
get
(
'port'
,
'N/A'
)
,
experiment_dict
[
key
].
get
(
'platform'
),
experiment_dict
[
key
].
get
(
'platform'
),
experiment_dict
[
key
][
'startTime'
],
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
experiment_dict
[
key
][
'startTime'
]
/
1000
))
if
isinstance
(
experiment_dict
[
key
][
'startTime'
],
int
)
else
experiment_dict
[
key
][
'startTime'
],
experiment_dict
[
key
][
'endTime'
])
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
experiment_dict
[
key
][
'endTime'
]
/
1000
))
if
isinstance
(
experiment_dict
[
key
][
'endTime'
],
int
)
else
experiment_dict
[
key
][
'endTime'
])
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
exit
(
1
)
exit
(
1
)
else
:
else
:
...
@@ -186,7 +184,7 @@ def get_experiment_port(args):
...
@@ -186,7 +184,7 @@ def get_experiment_port(args):
exit
(
1
)
exit
(
1
)
experiment_config
=
Experiments
()
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
return
experiment_dict
[
experiment_id
]
[
'port'
]
return
experiment_dict
[
experiment_id
]
.
get
(
'port'
)
def
convert_time_stamp_to_date
(
content
):
def
convert_time_stamp_to_date
(
content
):
'''Convert time stamp to date time format'''
'''Convert time stamp to date time format'''
...
@@ -202,8 +200,9 @@ def convert_time_stamp_to_date(content):
...
@@ -202,8 +200,9 @@ def convert_time_stamp_to_date(content):
def
check_rest
(
args
):
def
check_rest
(
args
):
'''check if restful server is running'''
'''check if restful server is running'''
nni_config
=
Config
(
get_config_filename
(
args
))
experiment_config
=
Experiments
()
rest_port
=
nni_config
.
get_config
(
'restServerPort'
)
experiment_dict
=
experiment_config
.
get_all_experiments
()
rest_port
=
experiment_dict
.
get
(
get_config_filename
(
args
)).
get
(
'port'
)
running
,
_
=
check_rest_server_quick
(
rest_port
)
running
,
_
=
check_rest_server_quick
(
rest_port
)
if
running
:
if
running
:
print_normal
(
'Restful server is running...'
)
print_normal
(
'Restful server is running...'
)
...
@@ -220,18 +219,19 @@ def stop_experiment(args):
...
@@ -220,18 +219,19 @@ def stop_experiment(args):
if
experiment_id_list
:
if
experiment_id_list
:
for
experiment_id
in
experiment_id_list
:
for
experiment_id
in
experiment_id_list
:
print_normal
(
'Stopping experiment %s'
%
experiment_id
)
print_normal
(
'Stopping experiment %s'
%
experiment_id
)
nni_config
=
Config
(
experiment_id
)
experiment_config
=
Experiments
()
rest_pid
=
nni_config
.
get_config
(
'restServerPid'
)
experiment_dict
=
experiment_config
.
get_all_experiments
()
rest_pid
=
experiment_dict
.
get
(
experiment_id
).
get
(
'pid'
)
if
rest_pid
:
if
rest_pid
:
kill_command
(
rest_pid
)
kill_command
(
rest_pid
)
tensorboard_pid_list
=
nni_config
.
get_config
(
'tensorboardPidList'
)
tensorboard_pid_list
=
experiment_dict
.
get
(
experiment_id
).
get
(
'tensorboardPidList'
)
if
tensorboard_pid_list
:
if
tensorboard_pid_list
:
for
tensorboard_pid
in
tensorboard_pid_list
:
for
tensorboard_pid
in
tensorboard_pid_list
:
try
:
try
:
kill_command
(
tensorboard_pid
)
kill_command
(
tensorboard_pid
)
except
Exception
as
exception
:
except
Exception
as
exception
:
print_error
(
exception
)
print_error
(
exception
)
nni_config
.
set_config
(
'tensorboardPidList'
,
[])
experiment_config
.
update_experiment
(
experiment_id
,
'tensorboardPidList'
,
[])
print_normal
(
'Stop experiment success.'
)
print_normal
(
'Stop experiment success.'
)
def
trial_ls
(
args
):
def
trial_ls
(
args
):
...
@@ -250,9 +250,10 @@ def trial_ls(args):
...
@@ -250,9 +250,10 @@ def trial_ls(args):
if
args
.
head
and
args
.
tail
:
if
args
.
head
and
args
.
tail
:
print_error
(
'Head and tail cannot be set at the same time.'
)
print_error
(
'Head and tail cannot be set at the same time.'
)
return
return
nni_config
=
Config
(
get_config_filename
(
args
))
experiment_config
=
Experiments
()
-    rest_port = nni_config.get_config('restServerPort')
-    rest_pid = nni_config.get_config('restServerPid')
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
+    rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
     if not detect_process(rest_pid):
         print_error('Experiment is not running...')
         return
...
@@ -281,9 +282,10 @@ def trial_ls(args):
 def trial_kill(args):
     '''List trial'''
-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
-    rest_pid = nni_config.get_config('restServerPid')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
+    rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
     if not detect_process(rest_pid):
         print_error('Experiment is not running...')
         return
...
@@ -312,9 +314,10 @@ def trial_codegen(args):
 def list_experiment(args):
     '''Get experiment information'''
-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
-    rest_pid = nni_config.get_config('restServerPid')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
+    rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
     if not detect_process(rest_pid):
         print_error('Experiment is not running...')
         return
...
@@ -333,8 +336,9 @@ def list_experiment(args):
 def experiment_status(args):
     '''Show the status of experiment'''
-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
     result, response = check_rest_server_quick(rest_port)
     if not result:
         print_normal('Restful server is not running...')
...
@@ -620,12 +624,12 @@ def platform_clean(args):
            break
     if platform == 'remote':
         machine_list = config_content.get('machineList')
-        remote_clean(machine_list, None)
+        remote_clean(machine_list)
     elif platform == 'pai':
         host = config_content.get('paiConfig').get('host')
         user_name = config_content.get('paiConfig').get('userName')
         output_dir = config_content.get('trial').get('outputDir')
-        hdfs_clean(host, user_name, output_dir, None)
+        hdfs_clean(host, user_name, output_dir)
     print_normal('Done.')

 def experiment_list(args):
...
@@ -651,7 +655,7 @@ def experiment_list(args):
         experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
                                   experiment_dict[key].get('experimentName', 'N/A'),
                                   experiment_dict[key]['status'],
-                                  experiment_dict[key]['port'],
+                                  experiment_dict[key].get('port', 'N/A'),
                                   experiment_dict[key].get('platform'),
                                   time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
                                   time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
...
@@ -752,9 +756,10 @@ def export_trials_data(args):
             groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
         return groupby

-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
-    rest_pid = nni_config.get_config('restServerPid')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
+    rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
     if not detect_process(rest_pid):
         print_error('Experiment is not running...')
...
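All of the nnictl_utils.py hunks above follow one pattern: instead of reading restServerPort / restServerPid from a per-experiment Config file, the port and pid are looked up in the shared Experiments registry, keyed by get_config_filename(args) (presumably the experiment ID). A minimal, hypothetical sketch of that lookup shape, with a plain dict standing in for the registry; the sample data and helper name are made up for illustration and are not NNI's API:

# Hypothetical sketch of the lookup pattern the diff migrates to: one shared
# registry of experiments, keyed by experiment id, instead of per-experiment
# config files. The registry contents below are made-up sample data.
experiments_registry = {
    'GXkznzP1': {'port': 8080, 'pid': 31337, 'status': 'RUNNING'},
    'a1b2c3d4': {'port': 8081, 'pid': 31400, 'status': 'DONE'},
}

def get_rest_endpoint(experiment_id):
    """Return (port, pid) for one experiment, or (None, None) if unknown."""
    meta = experiments_registry.get(experiment_id, {})
    return meta.get('port'), meta.get('pid')

port, pid = get_rest_endpoint('GXkznzP1')
print(port, pid)  # 8080 31337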
nni/tools/nnictl/tensorboard_utils.py  View file @ 349ead41
...
@@ -70,7 +70,7 @@ def format_tensorboard_log_path(path_list):
         new_path_list.append('name%d:%s' % (index + 1, value))
     return ','.join(new_path_list)

-def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
+def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path):
     '''call cmds to start tensorboard process in local machine'''
     if detect_port(args.port):
         print_error('Port %s is used by another process, please reset port!' % str(args.port))
...
@@ -83,20 +83,19 @@ def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
     url_list = get_local_urls(args.port)
     print_green('Start tensorboard success!')
     print_normal('Tensorboard urls: ' + ' '.join(url_list))
-    tensorboard_process_pid_list = nni_config.get_config('tensorboardPidList')
+    experiment_config = Experiments()
+    tensorboard_process_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
     if tensorboard_process_pid_list is None:
         tensorboard_process_pid_list = [tensorboard_process.pid]
     else:
         tensorboard_process_pid_list.append(tensorboard_process.pid)
-    nni_config.set_config('tensorboardPidList', tensorboard_process_pid_list)
+    experiment_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list)

 def stop_tensorboard(args):
     '''stop tensorboard'''
     experiment_id = check_experiment_id(args)
     experiment_config = Experiments()
-    experiment_dict = experiment_config.get_all_experiments()
-    nni_config = Config(experiment_id)
-    tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
+    tensorboard_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
     if tensorboard_pid_list:
         for tensorboard_pid in tensorboard_pid_list:
             try:
...
@@ -104,7 +103,7 @@ def stop_tensorboard(args):
                 call(cmds)
             except Exception as exception:
                 print_error(exception)
-        nni_config.set_config('tensorboardPidList', [])
+        experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
         print_normal('Stop tensorboard success!')
     else:
         print_error('No tensorboard configuration!')
...
@@ -164,4 +163,4 @@ def start_tensorboard(args):
     os.makedirs(temp_nni_path, exist_ok=True)
     path_list = get_path_list(args, nni_config, trial_content, temp_nni_path)
-    start_tensorboard_process(args, nni_config, path_list, temp_nni_path)
+    start_tensorboard_process(args, experiment_id, path_list, temp_nni_path)
\ No newline at end of file
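tensorboard_utils.py gets the same treatment: the TensorBoard PID list now lives in the Experiments registry and is written back with update_experiment() instead of nni_config.set_config(). A rough sketch of that append-then-persist shape, with a plain dict standing in for the persisted store; the store layout and helper name are illustrative assumptions, not NNI code:

# Illustrative only: a dict stands in for the persisted experiments store.
store = {'GXkznzP1': {'tensorboardPidList': None}}

def record_tensorboard_pid(experiment_id, new_pid):
    """Append a TensorBoard PID to the experiment's list and write it back."""
    pid_list = store[experiment_id].get('tensorboardPidList')
    if pid_list is None:
        pid_list = [new_pid]
    else:
        pid_list.append(new_pid)
    # in nnictl this write-back is Experiments().update_experiment(...)
    store[experiment_id]['tensorboardPidList'] = pid_list

record_tensorboard_pid('GXkznzP1', 4242)
record_tensorboard_pid('GXkznzP1', 4243)
print(store['GXkznzP1']['tensorboardPidList'])  # [4242, 4243]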
nni/tools/nnictl/updater.py  View file @ 349ead41
...
@@ -5,7 +5,7 @@ import json
 import os
 from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
 from .url_utils import experiment_url, import_data_url
-from .config_utils import Config
+from .config_utils import Config, Experiments
 from .common_utils import get_json_content, print_normal, print_error, print_warning
 from .nnictl_utils import get_experiment_port, get_config_filename, detect_process
 from .launcher_utils import parse_time
...
@@ -58,8 +58,9 @@ def get_query_type(key):
 def update_experiment_profile(args, key, value):
     '''call restful server to update experiment profile'''
-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
     running, _ = check_rest_server_quick(rest_port)
     if running:
         response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
...
pipelines/fast-test.yml  View file @ 349ead41
...
@@ -4,27 +4,26 @@
 jobs:
 - job: ubuntu_latest
   pool:
-    vmImage: ubuntu-latest
+    # FIXME: In ubuntu-20.04 Python interpreter crashed during SMAC UT
+    vmImage: ubuntu-18.04
   # This platform tests lint and doc first.
   steps:
   - task: UsePythonVersion@0
     inputs:
-      versionSpec: 3.6
+      versionSpec: 3.8
     displayName: Configure Python version

   - script: |
       set -e
-      python3 -m pip install --upgrade pip setuptools
-      python3 -m pip install pytest coverage
-      python3 -m pip install pylint flake8
+      python -m pip install --upgrade pip setuptools
+      python -m pip install pytest coverage
+      python -m pip install pylint flake8
       echo "##vso[task.setvariable variable=PATH]${HOME}/.local/bin:${PATH}"
     displayName: Install Python tools

   - script: |
-      python3 setup.py develop
+      python setup.py develop
     displayName: Install NNI

   - script: |
...
@@ -35,24 +34,28 @@ jobs:
       yarn eslint
     displayName: ESLint

+  # FIXME: temporarily fixed to pytorch 1.6 as 1.7 won't work with compression
   - script: |
       set -e
       sudo apt-get install -y pandoc
-      python3 -m pip install --upgrade pygments
-      python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
-      python3 -m pip install --upgrade tensorflow
-      python3 -m pip install --upgrade gym onnx peewee thop graphviz
-      python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
-      sudo apt-get install swig -y
-      python3 -m pip install -e .[SMAC,BOHB]
+      python -m pip install --upgrade pygments
+      python -m pip install "torch==1.6.0+cpu" "torchvision==0.7.0+cpu" -f https://download.pytorch.org/whl/torch_stable.html
+      python -m pip install tensorflow
+      python -m pip install gym onnx peewee thop graphviz
+      python -m pip install sphinx==3.3.1 sphinx-argparse==0.2.5 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 nbsphinx
+      sudo apt-get remove swig -y
+      sudo apt-get install swig3.0 -y
+      sudo ln -s /usr/bin/swig3.0 /usr/bin/swig
+      python -m pip install -e .[SMAC,BOHB]
     displayName: Install extra dependencies

   - script: |
       set -e
-      python3 -m pylint --rcfile pylintrc nni
-      python3 -m flake8 nni --count --select=E9,F63,F72,F82 --show-source --statistics
+      python -m pylint --rcfile pylintrc nni
+      python -m flake8 nni --count --select=E9,F63,F72,F82 --show-source --statistics
       EXCLUDES=examples/trials/mnist-nas/*/mnist*.py,examples/trials/nas_cifar10/src/cifar10/general_child.py
-      python3 -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
+      python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
     displayName: pylint and flake8

   - script: |
...
@@ -61,10 +64,14 @@ jobs:
     displayName: Check Sphinx documentation

   - script: |
+      set -e
       cd test
-      python3 -m pytest ut --ignore=ut/sdk/test_pruners.py --ignore=ut/sdk/test_compressor_tf.py
-      python3 -m pytest ut/sdk/test_pruners.py
-      python3 -m pytest ut/sdk/test_compressor_tf.py
+      python -m pytest ut --ignore=ut/sdk/test_pruners.py \
+          --ignore=ut/sdk/test_compressor_tf.py \
+          --ignore=ut/sdk/test_compressor_torch.py
+      python -m pytest ut/sdk/test_pruners.py
+      python -m pytest ut/sdk/test_compressor_tf.py
+      python -m pytest ut/sdk/test_compressor_torch.py
     displayName: Python unit test

   - script: |
...
@@ -77,7 +84,7 @@ jobs:
   - script: |
       cd test
-      python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
+      python nni_test/nnitest/run_tests.py --config config/pr_tests.yml
     displayName: Simple integration test
...
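The unit-test step now runs the pruner and compressor suites as separate pytest invocations and excludes them from the main run. The split itself is all the commit shows; isolating heavyweight TensorFlow/PyTorch state between modules is an assumed motivation. A rough Python equivalent of that shell step, assuming pytest is installed and the working directory is test/:

# Sketch of the split test run: one main run that ignores the isolated
# modules, then each isolated module in its own interpreter process.
import subprocess
import sys

isolated = [
    'ut/sdk/test_pruners.py',
    'ut/sdk/test_compressor_tf.py',
    'ut/sdk/test_compressor_torch.py',
]

ignores = [f'--ignore={path}' for path in isolated]
subprocess.run([sys.executable, '-m', 'pytest', 'ut', *ignores], check=True)

for path in isolated:
    subprocess.run([sys.executable, '-m', 'pytest', path], check=True)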
pipelines/integration-test-adl.yml  0 → 100644  View file @ 349ead41
trigger: none
pr: none

schedules:
- cron: 0 16 * * *
  branches:
    include: [ master ]

jobs:
- job: adl
  pool: NNI CI KUBE CLI
  timeoutInMinutes: 120
  steps:
  - script: |
      export NNI_RELEASE=999.$(date -u +%Y%m%d%H%M%S)
      echo "##vso[task.setvariable variable=PATH]${PATH}:${HOME}/.local/bin"
      echo "##vso[task.setvariable variable=NNI_RELEASE]${NNI_RELEASE}"
      echo "Working directory: ${PWD}"
      echo "NNI version: ${NNI_RELEASE}"
      echo "Build docker image: $(build_docker_image)"
      python3 -m pip install --upgrade pip setuptools
    displayName: Prepare
  - script: |
      set -e
      python3 setup.py build_ts
      python3 setup.py bdist_wheel -p manylinux1_x86_64
      python3 -m pip install dist/nni-${NNI_RELEASE}-py3-none-manylinux1_x86_64.whl[SMAC,BOHB]
    displayName: Build and install NNI
  - script: |
      set -e
      cd examples/tuners/customized_tuner
      python3 setup.py develop --user
      nnictl algo register --meta meta_file.yml
    displayName: Install customized tuner
  - script: |
      set -e
      docker login -u nnidev -p $(docker_hub_password)
      sed -i '$a RUN python3 -m pip install adaptdl tensorboard' Dockerfile
      sed -i '$a COPY examples /examples' Dockerfile
      sed -i '$a COPY test /test' Dockerfile
      echo '## Build docker image ##'
      docker build --build-arg NNI_RELEASE=${NNI_RELEASE} -t nnidev/nni-nightly .
      echo '## Upload docker image ##'
      docker push nnidev/nni-nightly
    condition: eq(variables['build_docker_image'], 'true')
    displayName: Build and upload docker image
  - script: |
      set -e
      cd test
      python3 nni_test/nnitest/generate_ts_config.py \
          --ts adl \
          --nni_docker_image nnidev/nni-nightly \
          --checkpoint_storage_class $(checkpoint_storage_class) \
          --checkpoint_storage_size $(checkpoint_storage_size) \
          --nni_manager_ip $(nni_manager_ip)
      python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts adl
    displayName: Integration test
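The docker step in this new pipeline appends extra layers to the existing Dockerfile with sed -i '$a ...' ($a appends after the last line) before building nnidev/nni-nightly. For readers less used to sed, the same edit in Python; the file name and appended lines are taken from the step above:

# Equivalent of the three `sed -i '$a ...' Dockerfile` commands: append lines
# at the end of the Dockerfile before `docker build` runs.
extra_layers = [
    'RUN python3 -m pip install adaptdl tensorboard',
    'COPY examples /examples',
    'COPY test /test',
]

with open('Dockerfile', 'a') as dockerfile:
    for line in extra_layers:
        dockerfile.write(line + '\n')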
pipelines/release.yml  View file @ 349ead41
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+
+trigger: none
+pr: none

 jobs:
 - job: validate_version_number
   pool:
...
@@ -13,9 +16,11 @@ jobs:
     displayName: Configure Python version

   - script: |
+      echo $(build_type)
+      echo $(NNI_RELEASE)
       export BRANCH_TAG=`git describe --tags --abbrev=0`
       echo $BRANCH_TAG
-      if [[ $BRANCH_TAG = v$(NNI_RELEASE) && $(NNI_RELEASE) =~ ^v[0-9](.[0-9])+$ ]]; then
+      if [[ $BRANCH_TAG == v$(NNI_RELEASE) && $(NNI_RELEASE) =~ ^[0-9](.[0-9])+$ ]]; then
         echo 'Build version match branch tag'
       else
         echo 'Build version does not match branch tag'
...
@@ -25,15 +30,16 @@ jobs:
     displayName: Validate release version number and branch tag

   - script: |
+      echo $(build_type)
       echo $(NNI_RELEASE)
-      if [[ $(NNI_RELEASE) =~ ^[0-9](.[0-9])+a[0-9]$ ]]; then
+      if [[ $(NNI_RELEASE) =~ ^[0-9](.[0-9])+(a|b|rc)[0-9]$ ]]; then
         echo 'Valid prerelease version $(NNI_RELEASE)'
         echo `git describe --tags --abbrev=0`
       else
         echo 'Invalid build version $(NNI_RELEASE)'
         exit 1
       fi
-    condition: ne( variables['build_type'], 'rerelease' )
+    condition: ne( variables['build_type'], 'release' )
     displayName: Validate prerelease version number

 - job: linux
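Both validation steps reduce to regular-expression checks on $(NNI_RELEASE): a release must look like a plain version such as 2.0, and a prerelease may now end in a, b, or rc plus a digit (previously only aN passed). A small Python check using the same patterns, handy for trying version strings locally:

# The same patterns the pipeline applies to $(NNI_RELEASE), tried on a few
# sample version strings.
import re

release_pattern = r'^[0-9](.[0-9])+$'                   # e.g. 2.0, 2.0.1
prerelease_pattern = r'^[0-9](.[0-9])+(a|b|rc)[0-9]$'   # e.g. 2.0a1, 2.0rc1

for version in ['2.0', '2.0.1', '2.0a1', '2.0rc1', 'v2.0']:
    is_release = bool(re.search(release_pattern, version))
    is_prerelease = bool(re.search(prerelease_pattern, version))
    print(f'{version:8} release={is_release} prerelease={is_prerelease}')

Note that the dot in these patterns is unescaped, so it matches any character: the intended strings pass, but the release pattern is looser than it looks (a string like 2.0a1 also satisfies it).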
...
@@ -49,22 +55,22 @@ jobs:
     displayName: Configure Python version

   - script: |
-      python -m pip install --upgrade pip setuptools twine
+      python -m pip install --upgrade pip setuptools wheel twine
       python test/vso_tools/build_wheel.py $(NNI_RELEASE)
-      if [ $(build_type) = 'release' ]
-      then
-        echo 'uploading release package to pypi...'
+    displayName: Build wheel
+
+  - script: |
+      if [[ $(build_type) == 'release' || $(build_type) == 'rc' ]]; then
+        echo 'uploading to pypi...'
         python -m twine upload -u nni -p $(pypi_password) dist/*
       else
-        echo 'uploading prerelease package to testpypi...'
+        echo 'uploading to testpypi...'
         python -m twine upload -u nni -p $(pypi_password) --repository-url https://test.pypi.org/legacy/ dist/*
       fi
-    displayName: Build and upload wheel
+    displayName: Upload wheel

   - script: |
-      if [ $(build_type) = 'release' ]
-      then
+      if [[ $(build_type) == 'release' || $(build_type) == 'rc' ]]; then
         docker login -u msranni -p $(docker_hub_password)
         export IMAGE_NAME=msranni/nni
       else
...
@@ -74,9 +80,11 @@ jobs:
       echo "## Building ${IMAGE_NAME}:$(NNI_RELEASE) ##"
       docker build --build-arg NNI_RELEASE=$(NNI_RELEASE) -t ${IMAGE_NAME} .
-      docker tag ${IMAGE_NAME} ${IMAGE_NAME}:$(NNI_RELEASE)
-      docker push ${IMAGE_NAME}
-      docker push ${IMAGE_NAME}:$(NNI_RELEASE)
+      docker tag ${IMAGE_NAME} ${IMAGE_NAME}:v$(NNI_RELEASE)
+      docker push ${IMAGE_NAME}:v$(NNI_RELEASE)
+      if [[ $(build_type) != 'rc' ]]; then
+        docker push ${IMAGE_NAME}
+      fi
     displayName: Build and upload docker image

 - job: macos
...
@@ -92,18 +100,19 @@ jobs:
     displayName: Configure Python version

   - script: |
-      python -m pip install --upgrade pip setuptools twine
+      python -m pip install --upgrade pip setuptools wheel twine
       python test/vso_tools/build_wheel.py $(NNI_RELEASE)
-      if [ $(build_type) = 'release' ]
-      then
+    displayName: Build wheel
+
+  - script: |
+      if [[ $(build_type) == 'release' || $(build_type) == 'rc' ]]; then
         echo '## uploading to pypi ##'
         python -m twine upload -u nni -p $(pypi_password) dist/*
       else
         echo '## uploading to testpypi ##'
         python -m twine upload -u nni -p $(pypi_password) --repository-url https://test.pypi.org/legacy/ dist/*
       fi
-    displayName: Build and upload wheel
+    displayName: Upload wheel

 - job: windows
   dependsOn: validate_version_number
...
@@ -118,15 +127,16 @@ jobs:
     displayName: Configure Python version

   - powershell: |
-      python -m pip install --upgrade pip setuptools twine
+      python -m pip install --upgrade pip setuptools wheel twine
       python test/vso_tools/build_wheel.py $(NNI_RELEASE)
-      if($env:BUILD_TYPE -eq 'release'){
+    displayName: Build wheel
+
+  - powershell: |
+      if ($env:BUILD_TYPE -eq 'release' -Or $env:BUILD_TYPE -eq 'rc') {
         Write-Host '## uploading to pypi ##'
         python -m twine upload -u nni -p $(pypi_password) dist/*
-      }
-      else{
+      } else {
         Write-Host '## uploading to testpypi ##'
         python -m twine upload -u nni -p $(pypi_password) --repository-url https://test.pypi.org/legacy/ dist/*
       }
-    displayName: Build and upload wheel
+    displayName: Upload wheel
setup.py  View file @ 349ead41
...
@@ -69,7 +69,6 @@ dependencies = [
     'PythonWebHDFS',
     'colorama',
     'scikit-learn>=0.23.2',
-    'pkginfo',
     'websockets',
     'filelock',
     'prettytable',
...
@@ -112,11 +111,8 @@ def _setup():
         python_requires='>=3.6',
         install_requires=dependencies,
         extras_require={
-            'SMAC': [
-                'ConfigSpaceNNI @ git+https://github.com/QuanluZhang/ConfigSpace.git',
-                'smac @ git+https://github.com/QuanluZhang/SMAC3.git'],
-            'BOHB': ['ConfigSpace==0.4.7', 'statsmodels==0.10.0'],
+            'SMAC': ['ConfigSpaceNNI', 'smac4nni'],
+            'BOHB': ['ConfigSpace==0.4.7', 'statsmodels==0.12.0'],
             'PPOTuner': ['enum34', 'gym']
         },
         setup_requires=['requests'],
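The SMAC extra now points at plain index packages instead of direct git URLs, so `python -m pip install -e .[SMAC,BOHB]` (as fast-test.yml does above) can resolve everything from PyPI; a likely motivation, though not stated in the commit, is that published wheels cannot depend on direct URL references. A minimal, hypothetical setup() showing just the extras_require mechanism; 'demo-pkg' and its metadata are placeholders, not NNI's:

# Minimal, hypothetical packaging sketch of extras_require; only the extras
# mapping is taken from the hunk above, everything else is placeholder.
from setuptools import setup

setup(
    name='demo-pkg',
    version='0.0.1',
    extras_require={
        'SMAC': ['ConfigSpaceNNI', 'smac4nni'],
        'BOHB': ['ConfigSpace==0.4.7', 'statsmodels==0.12.0'],
    },
)

Installing demo-pkg[SMAC] would then pull ConfigSpaceNNI and smac4nni alongside the base package.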
...
@@ -189,6 +185,7 @@ class Build(build):
             sys.exit('Please set environment variable "NNI_RELEASE=<release_version>"')
         if os.path.islink('nni_node/main.js'):
             sys.exit('A development build already exists. Please uninstall NNI and run "python3 setup.py clean --all".')
+        open('nni/version.py', 'w').write(f"__version__ = '{release}'")
         super().run()

 class Develop(develop):
...
@@ -212,6 +209,7 @@ class Develop(develop):
         super().finalize_options()

     def run(self):
+        open('nni/version.py', 'w').write("__version__ = '999.dev0'")
         if not self.skip_ts:
             setup_ts.build(release=None)
         super().run()
...
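The two added lines make nni/version.py the single place the version is stamped: a real build writes the release number taken from NNI_RELEASE, while `setup.py develop` writes the sentinel 999.dev0, presumably so a development install always reads as newer than any published release. A tiny sketch of the same idea; the helper below is illustrative and not part of setup.py:

# Illustrative helper mirroring the two added lines: stamp nni/version.py with
# either the real release number or the development sentinel.
def write_version_file(release=None, path='nni/version.py'):
    version = release if release else '999.dev0'
    with open(path, 'w') as f:
        f.write(f"__version__ = '{version}'")

# write_version_file('2.0')  -> __version__ = '2.0'
# write_version_file()       -> __version__ = '999.dev0'
# With the third-party 'packaging' library, Version('999.dev0') > Version('2.0').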
setup_ts.py  View file @ 349ead41
...
@@ -196,6 +196,7 @@ def copy_nni_node(version):
     package_json['version'] = version
     json.dump(package_json, open('nni_node/package.json', 'w'), indent=2)

+    # reinstall without development dependencies
     _yarn('ts/nni_manager', '--prod', '--cwd', str(Path('nni_node').resolve()))

     shutil.copytree('ts/webui/build', 'nni_node/static')
...
@@ -226,9 +227,9 @@ def _symlink(target_file, link_location):
 def _print(*args):
     if sys.platform == 'win32':
-        print(*args)
+        print(*args, flush=True)
     else:
-        print('\033[1;36m#', *args, '\033[0m')
+        print('\033[1;36m#', *args, '\033[0m', flush=True)

 generated_files = [
...
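The only change to _print() is flush=True, which pushes each message through Python's stdout buffering immediately; when stdout is a pipe (as in CI logs) rather than a TTY, unflushed prints can otherwise appear late or interleave oddly with subprocess output. A tiny standalone demonstration:

# flush=True writes the text out immediately instead of waiting for the
# buffer to fill or the process to exit; useful when stdout is piped.
import time

print('step 1 ...', flush=True)
time.sleep(1)   # without flush, the line above could appear only much later
print('step 2 done', flush=True)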