Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5136a86d
Unverified
Commit
5136a86d
authored
Mar 24, 2022
by
liuzhe-lz
Committed by
GitHub
Mar 24, 2022
Browse files
Typehint and copyright header (#4669)
parent
68347c5e
Changes
41
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
170 additions
and
60 deletions
+170
-60
nni/experiment/config/convert.py
nni/experiment/config/convert.py
+2
-0
nni/experiment/config/training_services/remote.py
nni/experiment/config/training_services/remote.py
+1
-1
nni/experiment/data.py
nni/experiment/data.py
+3
-0
nni/experiment/experiment.py
nni/experiment/experiment.py
+15
-15
nni/experiment/launcher.py
nni/experiment/launcher.py
+9
-7
nni/experiment/management.py
nni/experiment/management.py
+3
-0
nni/experiment/pipe.py
nni/experiment/pipe.py
+5
-3
nni/experiment/rest.py
nni/experiment/rest.py
+3
-0
nni/recoverable.py
nni/recoverable.py
+5
-3
nni/runtime/config.py
nni/runtime/config.py
+2
-2
nni/runtime/log.py
nni/runtime/log.py
+7
-2
nni/runtime/msg_dispatcher.py
nni/runtime/msg_dispatcher.py
+1
-0
nni/tools/jupyter_extension/__init__.py
nni/tools/jupyter_extension/__init__.py
+3
-0
nni/tools/jupyter_extension/management.py
nni/tools/jupyter_extension/management.py
+3
-0
nni/tools/jupyter_extension/proxy.py
nni/tools/jupyter_extension/proxy.py
+3
-0
nni/tools/nnictl/ts_management.py
nni/tools/nnictl/ts_management.py
+3
-0
nni/trial.py
nni/trial.py
+34
-13
nni/tuner.py
nni/tuner.py
+13
-10
nni/typehint.py
nni/typehint.py
+52
-4
pipelines/fast-test.yml
pipelines/fast-test.yml
+3
-0
No files found.
nni/experiment/config/convert.py
View file @
5136a86d
...
@@ -37,6 +37,8 @@ def to_v2(v1):
...
@@ -37,6 +37,8 @@ def to_v2(v1):
_move_field
(
v1_trial
,
v2
,
'command'
,
'trialCommand'
)
_move_field
(
v1_trial
,
v2
,
'command'
,
'trialCommand'
)
_move_field
(
v1_trial
,
v2
,
'codeDir'
,
'trialCodeDirectory'
)
_move_field
(
v1_trial
,
v2
,
'codeDir'
,
'trialCodeDirectory'
)
_move_field
(
v1_trial
,
v2
,
'gpuNum'
,
'trialGpuNumber'
)
_move_field
(
v1_trial
,
v2
,
'gpuNum'
,
'trialGpuNumber'
)
else
:
v1_trial
=
{}
for
algo_type
in
[
'tuner'
,
'assessor'
,
'advisor'
]:
for
algo_type
in
[
'tuner'
,
'assessor'
,
'advisor'
]:
v1_algo
=
v1
.
pop
(
algo_type
,
None
)
v1_algo
=
v1
.
pop
(
algo_type
,
None
)
...
...
nni/experiment/config/training_services/remote.py
View file @
5136a86d
...
@@ -53,7 +53,7 @@ class RemoteMachineConfig(ConfigBase):
...
@@ -53,7 +53,7 @@ class RemoteMachineConfig(ConfigBase):
if
self
.
password
is
not
None
:
if
self
.
password
is
not
None
:
warnings
.
warn
(
'SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.'
)
warnings
.
warn
(
'SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.'
)
elif
not
Path
(
self
.
ssh_key_file
).
is_file
():
elif
not
Path
(
self
.
ssh_key_file
).
is_file
():
# type: ignore
raise
ValueError
(
raise
ValueError
(
f
'RemoteMachineConfig: You must either provide password or a valid SSH key file "
{
self
.
ssh_key_file
}
"'
f
'RemoteMachineConfig: You must either provide password or a valid SSH key file "
{
self
.
ssh_key_file
}
"'
)
)
...
...
nni/experiment/data.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
import
json
import
json
from
typing
import
List
from
typing
import
List
...
...
nni/experiment/experiment.py
View file @
5136a86d
...
@@ -10,7 +10,7 @@ from pathlib import Path
...
@@ -10,7 +10,7 @@ from pathlib import Path
import
socket
import
socket
from
subprocess
import
Popen
from
subprocess
import
Popen
import
time
import
time
from
typing
import
Optional
,
Any
from
typing
import
Any
import
colorama
import
colorama
import
psutil
import
psutil
...
@@ -34,8 +34,7 @@ class RunMode(Enum):
...
@@ -34,8 +34,7 @@ class RunMode(Enum):
- Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
- Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
- Detach: do not stop NNI manager when Python script exits.
- Detach: do not stop NNI manager when Python script exits.
NOTE:
NOTE: This API is non-stable and is likely to get refactored in upcoming release.
This API is non-stable and is likely to get refactored in next release.
"""
"""
# TODO:
# TODO:
# NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
# NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
...
@@ -72,15 +71,15 @@ class Experiment:
...
@@ -72,15 +71,15 @@ class Experiment:
Web portal port. Or ``None`` if the experiment is not running.
Web portal port. Or ``None`` if the experiment is not running.
"""
"""
def
__init__
(
self
,
config_or_platform
:
ExperimentConfig
|
str
|
list
[
str
]
|
None
)
->
None
:
def
__init__
(
self
,
config_or_platform
:
ExperimentConfig
|
str
|
list
[
str
]
|
None
):
nni
.
runtime
.
log
.
init_logger_for_command_line
()
nni
.
runtime
.
log
.
init_logger_for_command_line
()
self
.
config
:
Optional
[
ExperimentConfig
]
=
None
self
.
config
:
ExperimentConfig
|
None
=
None
self
.
id
:
str
=
management
.
generate_experiment_id
()
self
.
id
:
str
=
management
.
generate_experiment_id
()
self
.
port
:
Optional
[
int
]
=
None
self
.
port
:
int
|
None
=
None
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_proc
:
Popen
|
psutil
.
Process
|
None
=
None
self
.
action
=
'create'
self
.
_
action
=
'create'
self
.
url_prefix
:
Optional
[
str
]
=
None
self
.
url_prefix
:
str
|
None
=
None
if
isinstance
(
config_or_platform
,
(
str
,
list
)):
if
isinstance
(
config_or_platform
,
(
str
,
list
)):
self
.
config
=
ExperimentConfig
(
config_or_platform
)
self
.
config
=
ExperimentConfig
(
config_or_platform
)
...
@@ -101,6 +100,7 @@ class Experiment:
...
@@ -101,6 +100,7 @@ class Experiment:
debug
debug
Whether to start in debug mode.
Whether to start in debug mode.
"""
"""
assert
self
.
config
is
not
None
if
run_mode
is
not
RunMode
.
Detach
:
if
run_mode
is
not
RunMode
.
Detach
:
atexit
.
register
(
self
.
stop
)
atexit
.
register
(
self
.
stop
)
...
@@ -114,7 +114,7 @@ class Experiment:
...
@@ -114,7 +114,7 @@ class Experiment:
log_dir
=
Path
.
home
()
/
f
'nni-experiments/
{
self
.
id
}
/log'
log_dir
=
Path
.
home
()
/
f
'nni-experiments/
{
self
.
id
}
/log'
nni
.
runtime
.
log
.
start_experiment_log
(
self
.
id
,
log_dir
,
debug
)
nni
.
runtime
.
log
.
start_experiment_log
(
self
.
id
,
log_dir
,
debug
)
self
.
_proc
=
launcher
.
start_experiment
(
self
.
action
,
self
.
id
,
config
,
port
,
debug
,
run_mode
,
self
.
url_prefix
)
self
.
_proc
=
launcher
.
start_experiment
(
self
.
_
action
,
self
.
id
,
config
,
port
,
debug
,
run_mode
,
self
.
url_prefix
)
assert
self
.
_proc
is
not
None
assert
self
.
_proc
is
not
None
self
.
port
=
port
# port will be None if start up failed
self
.
port
=
port
# port will be None if start up failed
...
@@ -144,16 +144,16 @@ class Experiment:
...
@@ -144,16 +144,16 @@ class Experiment:
_logger
.
warning
(
'Cannot gracefully stop experiment, killing NNI process...'
)
_logger
.
warning
(
'Cannot gracefully stop experiment, killing NNI process...'
)
kill_command
(
self
.
_proc
.
pid
)
kill_command
(
self
.
_proc
.
pid
)
self
.
id
=
None
self
.
id
=
None
# type: ignore
self
.
port
=
None
self
.
port
=
None
self
.
_proc
=
None
self
.
_proc
=
None
_logger
.
info
(
'Experiment stopped'
)
_logger
.
info
(
'Experiment stopped'
)
def
run
(
self
,
port
:
int
=
8080
,
wait_completion
:
bool
=
True
,
debug
:
bool
=
False
)
->
bool
:
def
run
(
self
,
port
:
int
=
8080
,
wait_completion
:
bool
=
True
,
debug
:
bool
=
False
)
->
bool
|
None
:
"""
"""
Run the experiment.
Run the experiment.
If ``wait_completion`` is True, this function will block until experiment finish or error.
If ``wait_completion`` is
``
True
``
, this function will block until experiment finish or error.
Return ``True`` when experiment done; or return ``False`` when experiment failed.
Return ``True`` when experiment done; or return ``False`` when experiment failed.
...
@@ -247,7 +247,7 @@ class Experiment:
...
@@ -247,7 +247,7 @@ class Experiment:
def
_resume
(
exp_id
,
exp_dir
=
None
):
def
_resume
(
exp_id
,
exp_dir
=
None
):
exp
=
Experiment
(
None
)
exp
=
Experiment
(
None
)
exp
.
id
=
exp_id
exp
.
id
=
exp_id
exp
.
action
=
'resume'
exp
.
_
action
=
'resume'
exp
.
config
=
launcher
.
get_stopped_experiment_config
(
exp_id
,
exp_dir
)
exp
.
config
=
launcher
.
get_stopped_experiment_config
(
exp_id
,
exp_dir
)
return
exp
return
exp
...
@@ -255,7 +255,7 @@ class Experiment:
...
@@ -255,7 +255,7 @@ class Experiment:
def
_view
(
exp_id
,
exp_dir
=
None
):
def
_view
(
exp_id
,
exp_dir
=
None
):
exp
=
Experiment
(
None
)
exp
=
Experiment
(
None
)
exp
.
id
=
exp_id
exp
.
id
=
exp_id
exp
.
action
=
'view'
exp
.
_
action
=
'view'
exp
.
config
=
launcher
.
get_stopped_experiment_config
(
exp_id
,
exp_dir
)
exp
.
config
=
launcher
.
get_stopped_experiment_config
(
exp_id
,
exp_dir
)
return
exp
return
exp
...
...
nni/experiment/launcher.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
__future__
import
annotations
import
contextlib
import
contextlib
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
datetime
import
datetime
from
datetime
import
datetime
...
@@ -126,9 +128,9 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
...
@@ -126,9 +128,9 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
return
proc
return
proc
def
_start_rest_server
(
nni_manager_args
,
run_mode
)
->
Tuple
[
int
,
Popen
]
:
def
_start_rest_server
(
nni_manager_args
,
run_mode
)
->
Popen
:
import
nni_node
import
nni_node
node_dir
=
Path
(
nni_node
.
__path__
[
0
])
node_dir
=
Path
(
nni_node
.
__path__
[
0
])
# type: ignore
node
=
str
(
node_dir
/
(
'node.exe'
if
sys
.
platform
==
'win32'
else
'node'
))
node
=
str
(
node_dir
/
(
'node.exe'
if
sys
.
platform
==
'win32'
else
'node'
))
main_js
=
str
(
node_dir
/
'main.js'
)
main_js
=
str
(
node_dir
/
'main.js'
)
cmd
=
[
node
,
'--max-old-space-size=4096'
,
main_js
]
cmd
=
[
node
,
'--max-old-space-size=4096'
,
main_js
]
...
@@ -151,10 +153,10 @@ def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
...
@@ -151,10 +153,10 @@ def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
from
subprocess
import
CREATE_NEW_PROCESS_GROUP
from
subprocess
import
CREATE_NEW_PROCESS_GROUP
return
Popen
(
cmd
,
stdout
=
out
,
stderr
=
err
,
cwd
=
node_dir
,
creationflags
=
CREATE_NEW_PROCESS_GROUP
)
return
Popen
(
cmd
,
stdout
=
out
,
stderr
=
err
,
cwd
=
node_dir
,
creationflags
=
CREATE_NEW_PROCESS_GROUP
)
else
:
else
:
return
Popen
(
cmd
,
stdout
=
out
,
stderr
=
err
,
cwd
=
node_dir
,
preexec_fn
=
os
.
setpgrp
)
return
Popen
(
cmd
,
stdout
=
out
,
stderr
=
err
,
cwd
=
node_dir
,
preexec_fn
=
os
.
setpgrp
)
# type: ignore
def
start_experiment_retiarii
(
exp_id
:
str
,
config
:
ExperimentConfig
,
port
:
int
,
debug
:
bool
)
->
Popen
:
def
start_experiment_retiarii
(
exp_id
,
config
,
port
,
debug
)
:
pipe
=
None
pipe
=
None
proc
=
None
proc
=
None
...
@@ -221,7 +223,7 @@ def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool
...
@@ -221,7 +223,7 @@ def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool
args
[
'dispatcher_pipe'
]
=
pipe_path
args
[
'dispatcher_pipe'
]
=
pipe_path
import
nni_node
import
nni_node
node_dir
=
Path
(
nni_node
.
__path__
[
0
])
node_dir
=
Path
(
nni_node
.
__path__
[
0
])
# type: ignore
node
=
str
(
node_dir
/
(
'node.exe'
if
sys
.
platform
==
'win32'
else
'node'
))
node
=
str
(
node_dir
/
(
'node.exe'
if
sys
.
platform
==
'win32'
else
'node'
))
main_js
=
str
(
node_dir
/
'main.js'
)
main_js
=
str
(
node_dir
/
'main.js'
)
cmd
=
[
node
,
'--max-old-space-size=4096'
,
main_js
]
cmd
=
[
node
,
'--max-old-space-size=4096'
,
main_js
]
...
@@ -259,8 +261,8 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,
...
@@ -259,8 +261,8 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,
def
get_stopped_experiment_config
(
exp_id
,
exp_dir
=
None
):
def
get_stopped_experiment_config
(
exp_id
,
exp_dir
=
None
):
config_json
=
get_stopped_experiment_config_json
(
exp_id
,
exp_dir
)
config_json
=
get_stopped_experiment_config_json
(
exp_id
,
exp_dir
)
# type: ignore
config
=
ExperimentConfig
(
**
config_json
)
config
=
ExperimentConfig
(
**
config_json
)
# type: ignore
if
exp_dir
and
not
os
.
path
.
samefile
(
exp_dir
,
config
.
experiment_working_directory
):
if
exp_dir
and
not
os
.
path
.
samefile
(
exp_dir
,
config
.
experiment_working_directory
):
msg
=
'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
msg
=
'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
_logger
.
warning
(
msg
,
exp_dir
,
config
.
experiment_working_directory
)
_logger
.
warning
(
msg
,
exp_dir
,
config
.
experiment_working_directory
)
...
...
nni/experiment/management.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
pathlib
import
Path
from
pathlib
import
Path
import
random
import
random
import
string
import
string
...
...
nni/experiment/pipe.py
View file @
5136a86d
from
io
import
BufferedIOBase
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
logging
import
os
import
os
import
sys
import
sys
...
@@ -25,7 +27,7 @@ if sys.platform == 'win32':
...
@@ -25,7 +27,7 @@ if sys.platform == 'win32':
_winapi
.
NULL
_winapi
.
NULL
)
)
def
connect
(
self
)
->
BufferedIOBase
:
def
connect
(
self
):
_winapi
.
ConnectNamedPipe
(
self
.
_handle
,
_winapi
.
NULL
)
_winapi
.
ConnectNamedPipe
(
self
.
_handle
,
_winapi
.
NULL
)
fd
=
msvcrt
.
open_osfhandle
(
self
.
_handle
,
0
)
fd
=
msvcrt
.
open_osfhandle
(
self
.
_handle
,
0
)
self
.
file
=
os
.
fdopen
(
fd
,
'w+b'
)
self
.
file
=
os
.
fdopen
(
fd
,
'w+b'
)
...
@@ -55,7 +57,7 @@ else:
...
@@ -55,7 +57,7 @@ else:
self
.
_socket
.
bind
(
self
.
path
)
self
.
_socket
.
bind
(
self
.
path
)
self
.
_socket
.
listen
(
1
)
# only accepts one connection
self
.
_socket
.
listen
(
1
)
# only accepts one connection
def
connect
(
self
)
->
BufferedIOBase
:
def
connect
(
self
):
conn
,
_
=
self
.
_socket
.
accept
()
conn
,
_
=
self
.
_socket
.
accept
()
self
.
file
=
conn
.
makefile
(
'rwb'
)
self
.
file
=
conn
.
makefile
(
'rwb'
)
return
self
.
file
return
self
.
file
...
...
nni/experiment/rest.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
logging
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
...
...
nni/recoverable.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
__future__
import
annotations
import
os
import
os
class
Recoverable
:
class
Recoverable
:
def
load_checkpoint
(
self
):
def
load_checkpoint
(
self
)
->
None
:
pass
pass
def
save_checkpoint
(
self
):
def
save_checkpoint
(
self
)
->
None
:
pass
pass
def
get_checkpoint_path
(
self
):
def
get_checkpoint_path
(
self
)
->
str
|
None
:
ckp_path
=
os
.
getenv
(
'NNI_CHECKPOINT_DIRECTORY'
)
ckp_path
=
os
.
getenv
(
'NNI_CHECKPOINT_DIRECTORY'
)
if
ckp_path
is
not
None
and
os
.
path
.
isdir
(
ckp_path
):
if
ckp_path
is
not
None
and
os
.
path
.
isdir
(
ckp_path
):
return
ckp_path
return
ckp_path
...
...
nni/runtime/config.py
View file @
5136a86d
...
@@ -14,7 +14,7 @@ def get_config_directory() -> Path:
...
@@ -14,7 +14,7 @@ def get_config_directory() -> Path:
Create it if not exist.
Create it if not exist.
"""
"""
if
os
.
getenv
(
'NNI_CONFIG_DIR'
)
is
not
None
:
if
os
.
getenv
(
'NNI_CONFIG_DIR'
)
is
not
None
:
config_dir
=
Path
(
os
.
getenv
(
'NNI_CONFIG_DIR'
))
config_dir
=
Path
(
os
.
getenv
(
'NNI_CONFIG_DIR'
))
# type: ignore
elif
sys
.
prefix
!=
sys
.
base_prefix
or
Path
(
sys
.
prefix
,
'conda-meta'
).
is_dir
():
elif
sys
.
prefix
!=
sys
.
base_prefix
or
Path
(
sys
.
prefix
,
'conda-meta'
).
is_dir
():
config_dir
=
Path
(
sys
.
prefix
,
'nni'
)
config_dir
=
Path
(
sys
.
prefix
,
'nni'
)
elif
sys
.
platform
==
'win32'
:
elif
sys
.
platform
==
'win32'
:
...
@@ -39,4 +39,4 @@ def get_builtin_config_file(name: str) -> Path:
...
@@ -39,4 +39,4 @@ def get_builtin_config_file(name: str) -> Path:
"""
"""
Get a readonly builtin config file.
Get a readonly builtin config file.
"""
"""
return
Path
(
nni
.
__path__
[
0
],
'runtime/default_config'
,
name
)
return
Path
(
nni
.
__path__
[
0
],
'runtime/default_config'
,
name
)
# type: ignore
nni/runtime/log.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
__future__
import
annotations
import
logging
import
logging
import
sys
import
sys
from
datetime
import
datetime
from
datetime
import
datetime
...
@@ -105,7 +110,7 @@ def _init_logger_standalone() -> None:
...
@@ -105,7 +110,7 @@ def _init_logger_standalone() -> None:
_register_handler
(
StreamHandler
(
sys
.
stdout
),
logging
.
INFO
)
_register_handler
(
StreamHandler
(
sys
.
stdout
),
logging
.
INFO
)
def
_prepare_log_dir
(
path
:
Optional
[
str
]
)
->
Path
:
def
_prepare_log_dir
(
path
:
Path
|
str
)
->
Path
:
if
path
is
None
:
if
path
is
None
:
return
Path
()
return
Path
()
ret
=
Path
(
path
)
ret
=
Path
(
path
)
...
@@ -148,7 +153,7 @@ class _LogFileWrapper(TextIOBase):
...
@@ -148,7 +153,7 @@ class _LogFileWrapper(TextIOBase):
def
__init__
(
self
,
log_file
:
TextIOBase
):
def
__init__
(
self
,
log_file
:
TextIOBase
):
self
.
file
:
TextIOBase
=
log_file
self
.
file
:
TextIOBase
=
log_file
self
.
line_buffer
:
Optional
[
str
]
=
None
self
.
line_buffer
:
Optional
[
str
]
=
None
self
.
line_start_time
:
Optional
[
datetime
]
=
None
self
.
line_start_time
:
datetime
=
datetime
.
fromtimestamp
(
0
)
def
write
(
self
,
s
:
str
)
->
int
:
def
write
(
self
,
s
:
str
)
->
int
:
cur_time
=
datetime
.
now
()
cur_time
=
datetime
.
now
()
...
...
nni/runtime/msg_dispatcher.py
View file @
5136a86d
...
@@ -212,6 +212,7 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -212,6 +212,7 @@ class MsgDispatcher(MsgDispatcherBase):
except
Exception
as
e
:
except
Exception
as
e
:
_logger
.
error
(
'Assessor error'
)
_logger
.
error
(
'Assessor error'
)
_logger
.
exception
(
e
)
_logger
.
exception
(
e
)
raise
if
isinstance
(
result
,
bool
):
if
isinstance
(
result
,
bool
):
result
=
AssessResult
.
Good
if
result
else
AssessResult
.
Bad
result
=
AssessResult
.
Good
if
result
else
AssessResult
.
Bad
...
...
nni/tools/jupyter_extension/__init__.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.
import
proxy
from
.
import
proxy
load_jupyter_server_extension
=
proxy
.
setup
load_jupyter_server_extension
=
proxy
.
setup
...
...
nni/tools/jupyter_extension/management.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
json
import
json
from
pathlib
import
Path
from
pathlib
import
Path
import
shutil
import
shutil
...
...
nni/tools/jupyter_extension/proxy.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
json
import
json
from
pathlib
import
Path
from
pathlib
import
Path
...
...
nni/tools/nnictl/ts_management.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
importlib
import
importlib
import
json
import
json
...
...
nni/trial.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
__future__
import
annotations
from
typing
import
Any
from
.common.serializer
import
dump
from
.common.serializer
import
dump
from
.runtime.env_vars
import
trial_env_vars
from
.runtime.env_vars
import
trial_env_vars
from
.runtime
import
platform
from
.runtime
import
platform
from
.typehint
import
Parameters
,
TrialMetric
__all__
=
[
__all__
=
[
'get_next_parameter'
,
'get_next_parameter'
,
'get_next_parameters'
,
'get_current_parameter'
,
'get_current_parameter'
,
'report_intermediate_result'
,
'report_intermediate_result'
,
'report_final_result'
,
'report_final_result'
,
...
@@ -23,7 +28,7 @@ _trial_id = platform.get_trial_id()
...
@@ -23,7 +28,7 @@ _trial_id = platform.get_trial_id()
_sequence_id
=
platform
.
get_sequence_id
()
_sequence_id
=
platform
.
get_sequence_id
()
def
get_next_parameter
():
def
get_next_parameter
()
->
Parameters
:
"""
"""
Get the hyperparameters generated by tuner.
Get the hyperparameters generated by tuner.
...
@@ -32,7 +37,7 @@ def get_next_parameter():
...
@@ -32,7 +37,7 @@ def get_next_parameter():
Examples
Examples
--------
--------
Assuming the search space is:
Assuming the
:doc:`
search space
</hpo/search_space>`
is:
.. code-block::
.. code-block::
...
@@ -52,16 +57,22 @@ def get_next_parameter():
...
@@ -52,16 +57,22 @@ def get_next_parameter():
Returns
Returns
-------
-------
dict
:class:`~nni.typehint.Parameters`
A hyperparameter set sampled from search space.
A hyperparameter set sampled from search space.
"""
"""
global
_params
global
_params
_params
=
platform
.
get_next_parameter
()
_params
=
platform
.
get_next_parameter
()
if
_params
is
None
:
if
_params
is
None
:
return
None
return
None
# type: ignore
return
_params
[
'parameters'
]
return
_params
[
'parameters'
]
def
get_current_parameter
(
tag
=
None
):
def
get_next_parameters
()
->
Parameters
:
"""
Alias of :func:`get_next_parameter`
"""
return
get_next_parameter
()
def
get_current_parameter
(
tag
:
str
|
None
=
None
)
->
Any
:
global
_params
global
_params
if
_params
is
None
:
if
_params
is
None
:
return
None
return
None
...
@@ -94,13 +105,13 @@ def get_sequence_id() -> int:
...
@@ -94,13 +105,13 @@ def get_sequence_id() -> int:
_intermediate_seq
=
0
_intermediate_seq
=
0
def
overwrite_intermediate_seq
(
value
)
:
def
overwrite_intermediate_seq
(
value
:
int
)
->
None
:
assert
isinstance
(
value
,
int
)
assert
isinstance
(
value
,
int
)
global
_intermediate_seq
global
_intermediate_seq
_intermediate_seq
=
value
_intermediate_seq
=
value
def
report_intermediate_result
(
metric
)
:
def
report_intermediate_result
(
metric
:
TrialMetric
|
dict
[
str
,
Any
])
->
None
:
"""
"""
Reports intermediate result to NNI.
Reports intermediate result to NNI.
...
@@ -110,11 +121,16 @@ def report_intermediate_result(metric):
...
@@ -110,11 +121,16 @@ def report_intermediate_result(metric):
and other items can be visualized with web portal.
and other items can be visualized with web portal.
Typically ``metric`` is per-epoch accuracy or loss.
Typically ``metric`` is per-epoch accuracy or loss.
Parameters
----------
metric : :class:`~nni.typehint.TrialMetric`
The intermeidate result.
"""
"""
global
_intermediate_seq
global
_intermediate_seq
assert
_params
or
trial_env_vars
.
NNI_PLATFORM
is
None
,
\
assert
_params
or
trial_env_vars
.
NNI_PLATFORM
is
None
,
\
'nni.get_next_parameter() needs to be called before report_intermediate_result'
'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric
=
dump
({
dumped_
metric
=
dump
({
'parameter_id'
:
_params
[
'parameter_id'
]
if
_params
else
None
,
'parameter_id'
:
_params
[
'parameter_id'
]
if
_params
else
None
,
'trial_job_id'
:
trial_env_vars
.
NNI_TRIAL_JOB_ID
,
'trial_job_id'
:
trial_env_vars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'PERIODICAL'
,
'type'
:
'PERIODICAL'
,
...
@@ -122,9 +138,9 @@ def report_intermediate_result(metric):
...
@@ -122,9 +138,9 @@ def report_intermediate_result(metric):
'value'
:
dump
(
metric
)
'value'
:
dump
(
metric
)
})
})
_intermediate_seq
+=
1
_intermediate_seq
+=
1
platform
.
send_metric
(
metric
)
platform
.
send_metric
(
dumped_
metric
)
def
report_final_result
(
metric
)
:
def
report_final_result
(
metric
:
TrialMetric
|
dict
[
str
,
Any
])
->
None
:
"""
"""
Reports final result to NNI.
Reports final result to NNI.
...
@@ -134,14 +150,19 @@ def report_final_result(metric):
...
@@ -134,14 +150,19 @@ def report_final_result(metric):
and other items can be visualized with web portal.
and other items can be visualized with web portal.
Typically ``metric`` is the final accuracy or loss.
Typically ``metric`` is the final accuracy or loss.
Parameters
----------
metric : :class:`~nni.typehint.TrialMetric`
The final result.
"""
"""
assert
_params
or
trial_env_vars
.
NNI_PLATFORM
is
None
,
\
assert
_params
or
trial_env_vars
.
NNI_PLATFORM
is
None
,
\
'nni.get_next_parameter() needs to be called before report_final_result'
'nni.get_next_parameter() needs to be called before report_final_result'
metric
=
dump
({
dumped_
metric
=
dump
({
'parameter_id'
:
_params
[
'parameter_id'
]
if
_params
else
None
,
'parameter_id'
:
_params
[
'parameter_id'
]
if
_params
else
None
,
'trial_job_id'
:
trial_env_vars
.
NNI_TRIAL_JOB_ID
,
'trial_job_id'
:
trial_env_vars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'FINAL'
,
'type'
:
'FINAL'
,
'sequence'
:
0
,
'sequence'
:
0
,
'value'
:
dump
(
metric
)
'value'
:
dump
(
metric
)
})
})
platform
.
send_metric
(
metric
)
platform
.
send_metric
(
dumped_
metric
)
nni/tuner.py
View file @
5136a86d
...
@@ -8,11 +8,14 @@ A new trial will run with this configuration.
...
@@ -8,11 +8,14 @@ A new trial will run with this configuration.
See :class:`Tuner`' specification and ``docs/en_US/tuners.rst`` for details.
See :class:`Tuner`' specification and ``docs/en_US/tuners.rst`` for details.
"""
"""
from
__future__
import
annotations
import
logging
import
logging
import
nni
import
nni
from
.recoverable
import
Recoverable
from
.recoverable
import
Recoverable
from
.typehint
import
Parameters
,
SearchSpace
,
TrialMetric
,
TrialRecord
__all__
=
[
'Tuner'
]
__all__
=
[
'Tuner'
]
...
@@ -67,7 +70,7 @@ class Tuner(Recoverable):
...
@@ -67,7 +70,7 @@ class Tuner(Recoverable):
:class:`~nni.algorithms.hpo.gp_tuner.gp_tuner.GPTuner`
:class:`~nni.algorithms.hpo.gp_tuner.gp_tuner.GPTuner`
"""
"""
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
def
generate_parameters
(
self
,
parameter_id
:
int
,
**
kwargs
)
->
Parameters
:
"""
"""
Abstract method which provides a set of hyper-parameters.
Abstract method which provides a set of hyper-parameters.
...
@@ -100,7 +103,7 @@ class Tuner(Recoverable):
...
@@ -100,7 +103,7 @@ class Tuner(Recoverable):
# we need to design a new exception for this purpose
# we need to design a new exception for this purpose
raise
NotImplementedError
(
'Tuner: generate_parameters not implemented'
)
raise
NotImplementedError
(
'Tuner: generate_parameters not implemented'
)
def
generate_multiple_parameters
(
self
,
parameter_id_list
,
**
kwargs
)
:
def
generate_multiple_parameters
(
self
,
parameter_id_list
:
list
[
int
],
**
kwargs
)
->
list
[
Parameters
]
:
"""
"""
Callback method which provides multiple sets of hyper-parameters.
Callback method which provides multiple sets of hyper-parameters.
...
@@ -135,7 +138,7 @@ class Tuner(Recoverable):
...
@@ -135,7 +138,7 @@ class Tuner(Recoverable):
result
.
append
(
res
)
result
.
append
(
res
)
return
result
return
result
def
receive_trial_result
(
self
,
parameter_id
,
parameters
,
value
,
**
kwargs
):
def
receive_trial_result
(
self
,
parameter_id
:
int
,
parameters
:
Parameters
,
value
:
TrialMetric
,
**
kwargs
)
->
None
:
"""
"""
Abstract method invoked when a trial reports its final result. Must override.
Abstract method invoked when a trial reports its final result. Must override.
...
@@ -165,7 +168,7 @@ class Tuner(Recoverable):
...
@@ -165,7 +168,7 @@ class Tuner(Recoverable):
# pylint: disable=attribute-defined-outside-init
# pylint: disable=attribute-defined-outside-init
self
.
_accept_customized
=
accept
self
.
_accept_customized
=
accept
def
trial_end
(
self
,
parameter_id
,
success
,
**
kwargs
):
def
trial_end
(
self
,
parameter_id
:
int
,
success
:
bool
,
**
kwargs
)
->
None
:
"""
"""
Abstract method invoked when a trial is completed or terminated. Do nothing by default.
Abstract method invoked when a trial is completed or terminated. Do nothing by default.
...
@@ -179,7 +182,7 @@ class Tuner(Recoverable):
...
@@ -179,7 +182,7 @@ class Tuner(Recoverable):
Unstable parameters which should be ignored by normal users.
Unstable parameters which should be ignored by normal users.
"""
"""
def
update_search_space
(
self
,
search_space
)
:
def
update_search_space
(
self
,
search_space
:
SearchSpace
)
->
None
:
"""
"""
Abstract method for updating the search space. Must override.
Abstract method for updating the search space. Must override.
...
@@ -194,21 +197,21 @@ class Tuner(Recoverable):
...
@@ -194,21 +197,21 @@ class Tuner(Recoverable):
"""
"""
raise
NotImplementedError
(
'Tuner: update_search_space not implemented'
)
raise
NotImplementedError
(
'Tuner: update_search_space not implemented'
)
def
load_checkpoint
(
self
):
def
load_checkpoint
(
self
)
->
None
:
"""
"""
Internal API under revising, not recommended for end users.
Internal API under revising, not recommended for end users.
"""
"""
checkpoin_path
=
self
.
get_checkpoint_path
()
checkpoin_path
=
self
.
get_checkpoint_path
()
_logger
.
info
(
'Load checkpoint ignored by tuner, checkpoint path: %s'
,
checkpoin_path
)
_logger
.
info
(
'Load checkpoint ignored by tuner, checkpoint path: %s'
,
checkpoin_path
)
def
save_checkpoint
(
self
):
def
save_checkpoint
(
self
)
->
None
:
"""
"""
Internal API under revising, not recommended for end users.
Internal API under revising, not recommended for end users.
"""
"""
checkpoin_path
=
self
.
get_checkpoint_path
()
checkpoin_path
=
self
.
get_checkpoint_path
()
_logger
.
info
(
'Save checkpoint ignored by tuner, checkpoint path: %s'
,
checkpoin_path
)
_logger
.
info
(
'Save checkpoint ignored by tuner, checkpoint path: %s'
,
checkpoin_path
)
def
import_data
(
self
,
data
)
:
def
import_data
(
self
,
data
:
list
[
TrialRecord
])
->
None
:
"""
"""
Internal API under revising, not recommended for end users.
Internal API under revising, not recommended for end users.
"""
"""
...
@@ -216,8 +219,8 @@ class Tuner(Recoverable):
...
@@ -216,8 +219,8 @@ class Tuner(Recoverable):
# data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
# data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
pass
pass
def
_on_exit
(
self
):
def
_on_exit
(
self
)
->
None
:
pass
pass
def
_on_error
(
self
):
def
_on_error
(
self
)
->
None
:
pass
pass
nni/typehint.py
View file @
5136a86d
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
"""
Types for static checking.
"""
__all__
=
[
'Literal'
,
'Parameters'
,
'SearchSpace'
,
'TrialMetric'
,
'TrialRecord'
,
]
import
sys
import
sys
import
typing
from
typing
import
Any
,
Dict
,
List
,
TYPE_CHECKING
if
typing
.
TYPE_CHECKING
or
sys
.
version_info
>=
(
3
,
8
):
if
TYPE_CHECKING
or
sys
.
version_info
>=
(
3
,
8
):
Literal
=
typing
.
Literal
from
typing
import
Literal
,
TypedDict
else
:
else
:
Literal
=
typing
.
Any
from
typing_extensions
import
Literal
,
TypedDict
Parameters
=
Dict
[
str
,
Any
]
"""
Return type of :func:`nni.get_next_parameter`.
For built-in tuners, this is a ``dict`` whose content is defined by :doc:`search space </hpo/search_space>`.
Customized tuners do not need to follow the constraint and can use anything serializable.
"""
class
_ParameterSearchSpace
(
TypedDict
):
_type
:
Literal
[
'choice'
,
'randint'
,
'uniform'
,
'loguniform'
,
'quniform'
,
'qloguniform'
,
'normal'
,
'lognormal'
,
'qnormal'
,
'qlognormal'
,
]
_value
:
List
[
Any
]
SearchSpace
=
Dict
[
str
,
_ParameterSearchSpace
]
"""
Type of ``experiment.config.search_space``.
For built-in tuners, the format is detailed in :doc:`/hpo/search_space`.
Customized tuners do not need to follow the constraint and can use anything serializable, except ``None``.
"""
TrialMetric
=
float
"""
Type of the metrics sent to :func:`nni.report_final_result` and :func:`nni.report_intermediate_result`.
For built-in tuners it must be a number (``float``, ``int``, ``numpy.float32``, etc).
Customized tuners do not need to follow this constraint and can use anything serializable.
"""
class
TrialRecord
(
TypedDict
):
parameter
:
Parameters
value
:
TrialMetric
pipelines/fast-test.yml
View file @
5136a86d
...
@@ -63,6 +63,9 @@ stages:
...
@@ -63,6 +63,9 @@ stages:
python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
displayName
:
flake8
displayName
:
flake8
-
script
:
|
python -m pyright nni
-
job
:
typescript
-
job
:
typescript
pool
:
pool
:
vmImage
:
ubuntu-latest
vmImage
:
ubuntu-latest
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment