OpenDAS / nni / Commits

Unverified commit d5857823, authored Dec 20, 2021 by liuzhe-lz, committed by GitHub on Dec 20, 2021
Parent: cb090e8c

Config refactor (#4370)

Showing 20 changed files with 1422 additions and 807 deletions (+1422, -807):
nni/experiment/config/training_services/kubeflow.py  (+49, -0)
nni/experiment/config/training_services/local.py  (+47, -0)
nni/experiment/config/training_services/openpai.py  (+60, -0)
nni/experiment/config/training_services/remote.py  (+72, -0)
nni/experiment/config/util.py  (+0, -101)
nni/experiment/config/utils/__init__.py  (+11, -0)
nni/experiment/config/utils/internal.py  (+174, -0)
nni/experiment/config/utils/public.py  (+68, -0)
nni/experiment/experiment.py  (+60, -40)
nni/experiment/launcher.py  (+134, -36)
nni/experiment/rest.py  (+20, -11)
nni/retiarii/experiment/pytorch.py  (+7, -10)
nni/runtime/log.py  (+3, -1)
nni/tools/nnictl/launcher.py  (+64, -597)
nni/tools/nnictl/legacy_launcher.py  (+619, -0)
nni/tools/nnictl/nnictl.py  (+2, -2)
test/config/integration_tests.yml  (+0, -3)
test/config/integration_tests_tf2.yml  (+0, -3)
test/config/pr_tests.yml  (+0, -3)
test/ut/experiment/assets/config.yaml  (+32, -0)
nni/experiment/config/kubeflow.py → nni/experiment/config/training_services/kubeflow.py

Before (nni/experiment/config/kubeflow.py):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from typing import Optional

from .base import ConfigBase
from .common import TrainingServiceConfig
from . import util

__all__ = ['KubeflowConfig', 'KubeflowRoleConfig', 'KubeflowStorageConfig', 'KubeflowNfsConfig', 'KubeflowAzureStorageConfig']

@dataclass(init=False)
class KubeflowStorageConfig(ConfigBase):
    storage_type: str
    server: Optional[str] = None
    path: Optional[str] = None
    azure_account: Optional[str] = None
    azure_share: Optional[str] = None
    key_vault_name: Optional[str] = None
    key_vault_key: Optional[str] = None

@dataclass(init=False)
class KubeflowNfsConfig(KubeflowStorageConfig):
    storage: str = 'nfs'
    server: str
    path: str

@dataclass(init=False)
class KubeflowAzureStorageConfig(ConfigBase):
    storage: str = 'azureStorage'
    azure_account: str
    azure_share: str
    key_vault_name: str
    key_vault_key: str

@dataclass(init=False)
class KubeflowRoleConfig(ConfigBase):
    ...
    command: str
    gpu_number: Optional[int] = 0
    cpu_number: int
    memory_size: str
    docker_image: str = 'msranni/nni:latest'
    code_directory: str

@dataclass(init=False)
class KubeflowConfig(TrainingServiceConfig):
    platform: str = 'kubeflow'
    operator: str
    api_version: str
    storage: KubeflowStorageConfig
    worker: Optional[KubeflowRoleConfig] = None
    ps: Optional[KubeflowRoleConfig] = None
    master: Optional[KubeflowRoleConfig] = None
    reuse_mode: Optional[bool] = True  # set reuse mode as true for v2 config

    def __init__(self, **kwargs):
        kwargs = util.case_insensitive(kwargs)
        kwargs['storage'] = util.load_config(KubeflowStorageConfig, kwargs.get('storage'))
        kwargs['worker'] = util.load_config(KubeflowRoleConfig, kwargs.get('worker'))
        kwargs['ps'] = util.load_config(KubeflowRoleConfig, kwargs.get('ps'))
        kwargs['master'] = util.load_config(KubeflowRoleConfig, kwargs.get('master'))
        super().__init__(**kwargs)

    _validation_rules = {
        'platform': lambda value: (value == 'kubeflow', 'cannot be modified'),
        'operator': lambda value: value in ['tf-operator', 'pytorch-operator']
    }

After (nni/experiment/config/training_services/kubeflow.py):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Configuration for Kubeflow training service.

Check the reference_ for explanation of each field.

You may also want to check `Kubeflow training service doc`_.

.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html

.. _Kubeflow training service doc: https://nni.readthedocs.io/en/stable/TrainingService/KubeflowMode.html
"""

__all__ = ['KubeflowConfig', 'KubeflowRoleConfig']

from dataclasses import dataclass
from typing import Optional, Union

from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
from .k8s_storage import K8sStorageConfig

@dataclass(init=False)
class KubeflowRoleConfig(ConfigBase):
    ...
    command: str
    gpu_number: Optional[int] = 0
    cpu_number: int
    memory_size: Union[str, int]
    docker_image: str = 'msranni/nni:latest'
    code_directory: str

@dataclass(init=False)
class KubeflowConfig(TrainingServiceConfig):
    platform: str = 'kubeflow'
    operator: str
    api_version: str
    storage: K8sStorageConfig
    worker: Optional[KubeflowRoleConfig] = None
    ps: Optional[KubeflowRoleConfig] = None
    master: Optional[KubeflowRoleConfig] = None
    reuse_mode: Optional[bool] = True  # set reuse mode as true for v2 config

    def _validate_canonical(self):
        super()._validate_canonical()
        assert self.operator in ['tf-operator', 'pytorch-operator']

\ No newline at end of file
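As a rough illustration of how the relocated class might be used from Python (a sketch, not part of this commit; the import path is an assumption, field values are illustrative, and required fields hidden in the collapsed hunk, such as storage, are omitted):

# Hedged sketch: build a worker role and a Kubeflow training service config.
from nni.experiment.config.training_services.kubeflow import KubeflowConfig, KubeflowRoleConfig

worker = KubeflowRoleConfig(
    command='python3 train.py',
    gpu_number=0,
    cpu_number=1,
    memory_size='8gb',        # memory_size accepts Union[str, int] after this change
    code_directory='.',
)
config = KubeflowConfig(operator='tf-operator', api_version='v1', worker=worker)
# _validate_canonical() asserts that operator is 'tf-operator' or 'pytorch-operator'.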
nni/experiment/config/training_services/local.py (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Configuration for local training service.

Check the reference_ for explanation of each field.

You may also want to check `local training service doc`_.

.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html

.. _local training service doc: https://nni.readthedocs.io/en/stable/TrainingService/LocalMode.html
"""

__all__ = ['LocalConfig']

from dataclasses import dataclass
from typing import List, Optional, Union

from ..training_service import TrainingServiceConfig
from .. import utils

@dataclass(init=False)
class LocalConfig(TrainingServiceConfig):
    platform: str = 'local'
    use_active_gpu: Optional[bool] = None
    max_trial_number_per_gpu: int = 1
    gpu_indices: Union[List[int], int, str, None] = None
    reuse_mode: bool = False

    def _canonicalize(self, parents):
        super()._canonicalize(parents)
        self.gpu_indices = utils.canonical_gpu_indices(self.gpu_indices)
        self.nni_manager_ip = None

    def _validate_canonical(self):
        super()._validate_canonical()
        utils.validate_gpu_indices(self.gpu_indices)
        if self.trial_gpu_number and self.use_active_gpu is None:
            raise ValueError(
                'LocalConfig: please set use_active_gpu to True if your system has GUI, '
                'or set it to False if the computer runs multiple experiments concurrently.'
            )
        if not self.trial_gpu_number and self.max_trial_number_per_gpu != 1:
            raise ValueError('LocalConfig: max_trial_number_per_gpu does not work without trial_gpu_number')
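A rough usage sketch of the validation behaviour above (not part of this commit; the construction path through ConfigBase kwargs and the surrounding ExperimentConfig wiring are assumptions):

# Hedged sketch of LocalConfig and its two validation errors.
from nni.experiment.config.training_services.local import LocalConfig

ts = LocalConfig(gpu_indices='0,1', max_trial_number_per_gpu=2)
# After canonicalization, gpu_indices '0,1' becomes [0, 1].
# If the experiment sets trial_gpu_number but leaves use_active_gpu as None,
# _validate_canonical() raises ValueError asking you to pick True or False.
# If trial_gpu_number is unset while max_trial_number_per_gpu != 1,
# it raises ValueError because the per-GPU limit has nothing to apply to.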
nni/experiment/config/openpai.py → nni/experiment/config/training_services/openpai.py

Before (nni/experiment/config/openpai.py):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from pathlib import Path, PurePosixPath
from typing import Any, Dict, Optional

from .base import PathLike
from .common import TrainingServiceConfig
from . import util

__all__ = ['OpenpaiConfig']

@dataclass(init=False)
class OpenpaiConfig(TrainingServiceConfig):
    platform: str = 'openpai'
    ...
    username: str
    token: str
    trial_cpu_number: int
    trial_memory_size: str
    storage_config_name: str
    docker_image: str = 'msranni/nni:latest'
    virtual_cluster: Optional[str]
    ...
    container_storage_mount_point: str
    reuse_mode: bool = True

    openpai_config: Optional[Dict[str, Any]] = None
    openpai_config_file: Optional[PathLike] = None

    _canonical_rules = {
        'host': lambda value: 'https://' + value if '://' not in value else value,
        'local_storage_mount_point': util.canonical_path,
        'openpai_config_file': util.canonical_path
    }

    _validation_rules = {
        'platform': lambda value: (value == 'openpai', 'cannot be modified'),
        'local_storage_mount_point': lambda value: Path(value).is_dir(),
        'container_storage_mount_point': lambda value: (PurePosixPath(value).is_absolute(), 'is not absolute'),
        'openpai_config_file': lambda value: Path(value).is_file()
    }

    def validate(self) -> None:
        super().validate()
        if self.openpai_config is not None and self.openpai_config_file is not None:
            raise ValueError('openpai_config and openpai_config_file can only be set one')

After (nni/experiment/config/training_services/openpai.py):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Configuration for OpenPAI training service.

Check the reference_ for explanation of each field.

You may also want to check `OpenPAI training service doc`_.

.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html

.. _OpenPAI training service doc: https://nni.readthedocs.io/en/stable/TrainingService/PaiMode.html
"""

__all__ = ['OpenpaiConfig']

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Union

from ..training_service import TrainingServiceConfig
from ..utils import PathLike

@dataclass(init=False)
class OpenpaiConfig(TrainingServiceConfig):
    platform: str = 'openpai'
    ...
    username: str
    token: str
    trial_cpu_number: int
    trial_memory_size: Union[str, int]
    storage_config_name: str
    docker_image: str = 'msranni/nni:latest'
    virtual_cluster: Optional[str]
    ...
    container_storage_mount_point: str
    reuse_mode: bool = True

    openpai_config: Optional[Dict] = None
    openpai_config_file: Optional[PathLike] = None

    def _canonicalize(self, parents):
        super()._canonicalize(parents)
        if '://' not in self.host:  # type: ignore
            self.host = 'https://' + self.host

    def _validate_canonical(self) -> None:
        super()._validate_canonical()
        if self.trial_gpu_number is None:
            raise ValueError('OpenpaiConfig: trial_gpu_number is not set')
        if not Path(self.local_storage_mount_point).is_dir():
            raise ValueError(
                f'OpenpaiConfig: local_storage_mount_point "{self.local_storage_mount_point}" is not a directory'
            )
        if self.openpai_config is not None and self.openpai_config_file is not None:
            raise ValueError('openpai_config and openpai_config_file can only be set one')
        if self.openpai_config_file is not None and not Path(self.openpai_config_file).is_file():
            raise ValueError(f'OpenpaiConfig: openpai_config_file "{self.openpai_config_file}" is not a file')
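A rough sketch of the canonicalization and validation paths above (not part of this commit; the import path, field values, and the full set of required fields are assumptions):

# Hedged sketch of OpenpaiConfig behaviour.
from nni.experiment.config.training_services.openpai import OpenpaiConfig

conf = OpenpaiConfig(
    host='pai.example.com',          # _canonicalize() prepends 'https://' when no scheme is given
    username='alice',
    token='***',
    trial_cpu_number=2,
    trial_memory_size='8gb',
    storage_config_name='confignfs',
    local_storage_mount_point='/mnt/nfs',
    container_storage_mount_point='/mnt/data',
)
# Setting both openpai_config and openpai_config_file raises
# ValueError('openpai_config and openpai_config_file can only be set one').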
nni/experiment/config/training_services/remote.py (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Configuration for remote training service.

Check the reference_ for explanation of each field.

You may also want to check `remote training service doc`_.

.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html

.. _remote training service doc: https://nni.readthedocs.io/en/stable/TrainingService/RemoteMachineMode.html
"""

__all__ = ['RemoteConfig', 'RemoteMachineConfig']

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
import warnings

from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
from .. import utils

@dataclass(init=False)
class RemoteMachineConfig(ConfigBase):
    host: str
    port: int = 22
    user: str
    password: Optional[str] = None
    ssh_key_file: Optional[utils.PathLike] = '~/.ssh/id_rsa'
    ssh_passphrase: Optional[str] = None
    use_active_gpu: bool = False
    max_trial_number_per_gpu: int = 1
    gpu_indices: Union[List[int], int, str, None] = None
    python_path: Optional[str] = None

    def _canonicalize(self, parents):
        super()._canonicalize(parents)
        if self.password is not None:
            self.ssh_key_file = None
        self.gpu_indices = utils.canonical_gpu_indices(self.gpu_indices)

    def _validate_canonical(self):
        super()._validate_canonical()
        assert 0 < self.port < 65536
        assert self.max_trial_number_per_gpu > 0
        utils.validate_gpu_indices(self.gpu_indices)
        if self.password is not None:
            warnings.warn('SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.')
        elif not Path(self.ssh_key_file).is_file():
            raise ValueError(f'RemoteMachineConfig: You must either provide password or a valid SSH key file "{self.ssh_key_file}"')

@dataclass(init=False)
class RemoteConfig(TrainingServiceConfig):
    platform: str = 'remote'
    machine_list: List[RemoteMachineConfig]
    reuse_mode: bool = True

    def _validate_canonical(self):
        super()._validate_canonical()
        if not self.machine_list:
            raise ValueError(f'RemoteConfig: must provide at least one machine in machine_list')
        if not self.trial_gpu_number and any(machine.max_trial_number_per_gpu != 1 for machine in self.machine_list):
            raise ValueError('RemoteConfig: max_trial_number_per_gpu does not work without trial_gpu_number')
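A rough usage sketch of the remote training service config above (not part of this commit; import path and host/user values are illustrative assumptions):

# Hedged sketch of RemoteConfig with one machine.
from nni.experiment.config.training_services.remote import RemoteConfig, RemoteMachineConfig

machine = RemoteMachineConfig(host='10.0.0.2', user='alice', ssh_key_file='~/.ssh/id_rsa')
ts = RemoteConfig(machine_list=[machine])
# _validate_canonical() rejects an empty machine_list, checks 0 < port < 65536,
# and warns when a plain-text password is used instead of an SSH key file.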
nni/experiment/config/util.py (deleted, 100644 → 0)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Miscellaneous utility functions.
"""

import importlib
import json
import math
import os.path
from pathlib import Path
from typing import Any, Dict, Optional, Union, List

import nni.runtime.config

PathLike = Union[Path, str]

def case_insensitive(key_or_kwargs: Union[str, Dict[str, Any]]) -> Union[str, Dict[str, Any]]:
    if isinstance(key_or_kwargs, str):
        return key_or_kwargs.lower().replace('_', '')
    else:
        return {key.lower().replace('_', ''): value for key, value in key_or_kwargs.items()}

def camel_case(key: str) -> str:
    words = key.strip('_').split('_')
    return words[0] + ''.join(word.title() for word in words[1:])

def canonical_path(path: Optional[PathLike]) -> Optional[str]:
    # Path.resolve() does not work on Windows when file not exist, so use os.path instead
    return os.path.abspath(os.path.expanduser(path)) if path is not None else None

def count(*values) -> int:
    return sum(value is not None and value is not False for value in values)

def training_service_config_factory(platform: Union[str, List[str]] = None,
                                    config: Union[List, Dict] = None,
                                    base_path: Optional[Path] = None):  # -> TrainingServiceConfig
    from .common import TrainingServiceConfig

    # import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
    custom_ts_config_path = nni.runtime.config.get_config_file('training_services.json')
    custom_ts_config = json.load(custom_ts_config_path.open())
    for custom_ts_pkg in custom_ts_config.keys():
        pkg = importlib.import_module(custom_ts_pkg)
        _config_class = pkg.nni_training_service_info.config_class

    ts_configs = []
    if platform is not None:
        assert config is None
        platforms = platform if isinstance(platform, list) else [platform]
        for cls in TrainingServiceConfig.__subclasses__():
            if cls.platform in platforms:
                ts_configs.append(cls())
        if len(ts_configs) < len(platforms):
            bad = ', '.join(set(platforms) - set(ts_configs))
            raise RuntimeError(f'Bad training service platform: {bad}')
    else:
        assert config is not None
        supported_platforms = {cls.platform: cls for cls in TrainingServiceConfig.__subclasses__()}
        configs = config if isinstance(config, list) else [config]
        for conf in configs:
            if conf['platform'] not in supported_platforms:
                raise RuntimeError(f'Unrecognized platform {conf["platform"]}')
            ts_configs.append(supported_platforms[conf['platform']](_base_path=base_path, **conf))
    return ts_configs if len(ts_configs) > 1 else ts_configs[0]

def load_config(Type, value):
    if isinstance(value, list):
        return [load_config(Type, item) for item in value]
    if isinstance(value, dict):
        return Type(**value)
    return value

def strip_optional(type_hint):
    return type_hint.__args__[0] if str(type_hint).startswith('typing.Optional[') else type_hint

def parse_time(time: str, target_unit: str = 's') -> int:
    return _parse_unit(time.lower(), target_unit, _time_units)

def parse_size(size: str, target_unit: str = 'mb') -> int:
    return _parse_unit(size.lower(), target_unit, _size_units)

_time_units = {'d': 24 * 3600, 'h': 3600, 'm': 60, 's': 1}
_size_units = {'gb': 1024 * 1024 * 1024, 'mb': 1024 * 1024, 'kb': 1024}

def _parse_unit(string, target_unit, all_units):
    for unit, factor in all_units.items():
        if string.endswith(unit):
            number = string[:-len(unit)]
            value = float(number) * factor
            return math.ceil(value / all_units[target_unit])
    raise ValueError(f'Unsupported unit in "{string}"')

def canonical_gpu_indices(indices: Union[List[int], str, int, None]) -> Optional[List[int]]:
    if isinstance(indices, str):
        return [int(idx) for idx in indices.split(',')]
    if isinstance(indices, int):
        return [indices]
    return indices
nni/experiment/config/utils/__init__.py (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Utility functions for experiment config classes.

Check "public.py" to see which functions you can utilize.
"""

from .public import *
from .internal import *
nni/experiment/config/utils/internal.py (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Utility functions for experiment config classes, internal part.

If you are implementing a config class for a training service, it's unlikely you will need these.
"""

import dataclasses
import importlib
import json
import os.path
from pathlib import Path
import socket

import typeguard

import nni.runtime.config

from .public import is_missing

## handle relative path ##

_current_base_path = None

def get_base_path():
    if _current_base_path is None:
        return Path()
    return _current_base_path

def set_base_path(path):
    global _current_base_path
    assert _current_base_path is None
    _current_base_path = path

def unset_base_path():
    global _current_base_path
    _current_base_path = None

def resolve_path(path, base_path):
    if path is None:
        return None
    # Path.resolve() does not work on Windows when file not exist, so use os.path instead
    path = os.path.expanduser(path)
    if not os.path.isabs(path):
        path = os.path.join(base_path, path)
    return str(os.path.realpath(path))  # it should be already str, but official doc does not specify its type

## field name case conversion ##

def case_insensitive(key):
    return key.lower().replace('_', '')

def camel_case(key):
    words = key.strip('_').split('_')
    return words[0] + ''.join(word.title() for word in words[1:])

## type hint utils ##

def is_instance(value, type_hint):
    try:
        typeguard.check_type('_', value, type_hint)
    except TypeError:
        return False
    return True

def validate_type(config):
    class_name = type(config).__name__
    for field in dataclasses.fields(config):
        value = getattr(config, field.name)
        # check existence
        if is_missing(value):
            raise ValueError(f'{class_name}: {field.name} is not set')
        if not is_instance(value, field.type):
            raise ValueError(f'{class_name}: type of {field.name} ({repr(value)}) is not {field.type}')

def is_path_like(type_hint):
    # only `PathLike` and `Any` accepts `Path`; check `int` to make sure it's not `Any`
    return is_instance(Path(), type_hint) and not is_instance(1, type_hint)

## type inference ##

def guess_config_type(obj, type_hint):
    ret = guess_list_config_type([obj], type_hint, _hint_list_item=True)
    return ret[0] if ret else None

def guess_list_config_type(objs, type_hint, _hint_list_item=False):
    # avoid circular import
    from ..base import ConfigBase
    from ..training_service import TrainingServiceConfig

    # because __init__ of subclasses might be complex, we first create empty objects to determine type
    candidate_classes = []
    for cls in _all_subclasses(ConfigBase):
        if issubclass(cls, TrainingServiceConfig):  # training service configs are specially handled
            continue
        empty_list = [cls.__new__(cls)]
        if _hint_list_item:
            good_type = is_instance(empty_list[0], type_hint)
        else:
            good_type = is_instance(empty_list, type_hint)
        if good_type:
            candidate_classes.append(cls)

    if not candidate_classes:  # it does not accept config type
        return None
    if len(candidate_classes) == 1:  # the type is confirmed, raise error if cannot convert to this type
        return [candidate_classes[0](**obj) for obj in objs]

    # multiple candidates available, call __init__ to further verify
    candidate_configs = []
    for cls in candidate_classes:
        try:
            configs = [cls(**obj) for obj in objs]
        except Exception:
            continue
        candidate_configs.append(configs)

    if not candidate_configs:
        return None
    if len(candidate_configs) == 1:
        return candidate_configs[0]

    # still have multiple candidates, choose the common base class
    for base in candidate_configs:
        base_class = type(base[0])
        is_base = all(isinstance(configs[0], base_class) for configs in candidate_configs)
        if is_base:
            return base

    return None  # cannot detect the type, give up

def _all_subclasses(cls):
    subclasses = set(cls.__subclasses__())
    return subclasses.union(*[_all_subclasses(subclass) for subclass in subclasses])

def training_service_config_factory(platform):
    cls = _get_ts_config_class(platform)
    if cls is None:
        raise ValueError(f'Bad training service platform: {platform}')
    return cls()

def load_training_service_config(config):
    if isinstance(config, dict) and 'platform' in config:
        cls = _get_ts_config_class(config['platform'])
        if cls is not None:
            return cls(**config)
    return config  # not valid json, don't touch

def _get_ts_config_class(platform):
    from ..training_service import TrainingServiceConfig  # avoid circular import

    # import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
    custom_ts_config_path = nni.runtime.config.get_config_file('training_services.json')
    with custom_ts_config_path.open() as config_file:
        custom_ts_config = json.load(config_file)
    for custom_ts_pkg in custom_ts_config.keys():
        pkg = importlib.import_module(custom_ts_pkg)
        _config_class = pkg.nni_training_service_info.config_class

    for cls in TrainingServiceConfig.__subclasses__():
        if cls.platform == platform:
            return cls
    return None

## misc ##

def get_ipv4_address():
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.connect(('192.0.2.0', 80))
    addr = s.getsockname()[0]
    s.close()
    return addr
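A quick illustration of the type-hint helpers above (not part of this commit; assumes NNI is installed at this revision together with typeguard 2.x, whose check_type signature is used here):

# Hedged sketch: is_instance() wraps typeguard.check_type and returns a bool instead of raising.
from typing import Optional, Union
from nni.experiment.config.utils.internal import is_instance

assert is_instance(3, Optional[int])          # 3 matches Optional[int]
assert not is_instance('3', Optional[int])    # a str does not
assert is_instance('8gb', Union[str, int])    # memory-size style hints accept either
# validate_type() builds on this: it walks dataclasses.fields(config) and raises
# ValueError when a field is still MISSING or fails the is_instance() check.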
nni/experiment/config/utils/public.py (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Utility functions for experiment config classes.
"""

import dataclasses
import math
from pathlib import Path
from typing import Union

PathLike = Union[Path, str]

def is_missing(value):
    """
    Used to check whether a dataclass field has ever been assigned.

    If a field without default value has never been assigned, it will have a special value ``MISSING``.
    This function checks if the parameter is ``MISSING``.
    """
    # MISSING is not singleton and there is no official API to check it
    return isinstance(value, type(dataclasses.MISSING))

def canonical_gpu_indices(indices):
    """
    If ``indices`` is not None, cast it to list of int.
    """
    if isinstance(indices, str):
        return [int(idx) for idx in indices.split(',')]
    if isinstance(indices, int):
        return [indices]
    return indices

def validate_gpu_indices(indices):
    if indices is None:
        return
    if len(set(indices)) != len(indices):
        raise ValueError(f'Duplication detected in GPU indices {indices}')
    if any(idx < 0 for idx in indices):
        raise ValueError(f'Negative detected in GPU indices {indices}')

def parse_time(value):
    """
    If ``value`` is a string, convert it to integral number of seconds.
    """
    return _parse_unit(value, 's', _time_units)

def parse_memory_size(value):
    """
    If ``value`` is a string, convert it to integral number of megabytes.
    """
    return _parse_unit(value, 'mb', _size_units)

_time_units = {'d': 24 * 3600, 'h': 3600, 'm': 60, 's': 1}
_size_units = {'tb': 1024 ** 4, 'gb': 1024 ** 3, 'mb': 1024 ** 2, 'kb': 1024, 'b': 1}

def _parse_unit(value, target_unit, all_units):
    if not isinstance(value, str):
        return value
    value = value.lower()
    for unit, factor in all_units.items():
        if value.endswith(unit):
            number = value[:-len(unit)]
            value = float(number) * factor
            return math.ceil(value / all_units[target_unit])
    supported_units = ', '.join(all_units.keys())
    raise ValueError(f'Bad unit in "{value}", supported units are {supported_units}')
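A short worked example of the unit parsing above (not part of this commit; assumes NNI is installed at this revision):

# Hedged sketch of parse_time / parse_memory_size behaviour.
from nni.experiment.config.utils.public import parse_time, parse_memory_size

parse_time('1.5h')          # 1.5 * 3600 -> 5400 seconds
parse_time(90)              # non-strings pass through unchanged -> 90
parse_memory_size('8gb')    # 8 * 1024**3 bytes expressed in MB -> 8192
parse_memory_size('500kb')  # math.ceil(500 * 1024 / 1024**2) -> 1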
nni/experiment/experiment.py

 import atexit
+from enum import Enum
 import logging
 from pathlib import Path
 import socket
 ...
@@ -12,7 +13,7 @@ import psutil
 import nni.runtime.log
 from nni.common import dump
-from .config import ExperimentConfig, AlgorithmConfig
+from .config import ExperimentConfig
 from .data import TrialJob, TrialMetricData, TrialResult
 from . import launcher
 from . import management
 ...
@@ -21,6 +22,17 @@ from ..tools.nnictl.command_utils import kill_command
 _logger = logging.getLogger('nni.experiment')

+class RunMode(Enum):
+    """
+    Config lifecycle and output redirection of NNI manager process.
+
+    - Background: stop NNI manager when Python script exits; do not print NNI manager log. (default)
+    - Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
+    - Detach: do not stop NNI manager when Python script exits.
+    """
+    Background = 'background'
+    Foreground = 'foreground'
+    Detach = 'detach'
+
 class Experiment:
     """
 ...
@@ -73,21 +85,19 @@ class Experiment:
         nni.runtime.log.init_logger_experiment()

         self.config: Optional[ExperimentConfig] = None
-        self.id: Optional[str] = None
+        self.id: str = management.generate_experiment_id()
         self.port: Optional[int] = None
         self._proc: Optional[Popen] = None
         self.mode = 'new'
+        self.url_prefix: Optional[str] = None

         args = [config, training_service]  # deal with overloading
         if isinstance(args[0], (str, list)):
             self.config = ExperimentConfig(args[0])
-            self.config.tuner = AlgorithmConfig(name='_none_', class_args={})
-            self.config.assessor = AlgorithmConfig(name='_none_', class_args={})
-            self.config.advisor = AlgorithmConfig(name='_none_', class_args={})
         else:
             self.config = args[0]

-    def start(self, port: int = 8080, debug: bool = False) -> None:
+    def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
         """
         Start the experiment in background.
 ...
@@ -101,25 +111,25 @@ class Experiment:
         debug
             Whether to start in debug mode.
         """
-        atexit.register(self.stop)
+        if run_mode is not RunMode.Detach:
+            atexit.register(self.stop)

-        if self.mode == 'new':
-            self.id = management.generate_experiment_id()
-        else:
-            self.config = launcher.get_stopped_experiment_config(self.id, self.mode)
+        config = self.config.canonical_copy()
+        if config.use_annotation:
+            raise RuntimeError('NNI annotation is not supported by Python experiment API.')

-        if self.config.experiment_working_directory is not None:
-            log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
+        if config.experiment_working_directory is not None:
+            log_dir = Path(config.experiment_working_directory, self.id, 'log')
         else:  # this should never happen in latest version, keep it until v2.7 for potential compatibility
             log_dir = Path.home() / f'nni-experiments/{self.id}/log'
         nni.runtime.log.start_experiment_log(self.id, log_dir, debug)

-        self._proc = launcher.start_experiment(self.id, self.config, port, debug, mode=self.mode)
+        self._proc = launcher.start_experiment(self.mode, self.id, config, port, debug, run_mode, self.url_prefix)
         assert self._proc is not None

         self.port = port  # port will be None if start up failed

-        ips = [self.config.nni_manager_ip]
+        ips = [config.nni_manager_ip]
         for interfaces in psutil.net_if_addrs().values():
             for interface in interfaces:
                 if interface.family == socket.AF_INET:
 ...
@@ -135,11 +145,10 @@ class Experiment:
         _logger.info('Stopping experiment, please wait...')
         atexit.unregister(self.stop)

-        if self.id is not None:
-            nni.runtime.log.stop_experiment_log(self.id)
+        nni.runtime.log.stop_experiment_log(self.id)
         if self._proc is not None:
             try:
-                rest.delete(self.port, '/experiment')
+                rest.delete(self.port, '/experiment', self.url_prefix)
             except Exception as e:
                 _logger.exception(e)
                 _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
 ...
@@ -197,8 +206,8 @@ class Experiment:
         _logger.info('Connect to port %d success, experiment id is %s, status is %s.', port, experiment.id, status)
         return experiment

-    @classmethod
-    def resume(cls, experiment_id: str, port: int = 8080, wait_completion: bool = True, debug: bool = False):
+    @staticmethod
+    def resume(experiment_id: str, port: int = 8080, wait_completion: bool = True, debug: bool = False):
         """
         Resume a stopped experiment.
 ...
@@ -213,15 +222,13 @@ class Experiment:
         debug
             Whether to start in debug mode.
         """
-        experiment = Experiment()
-        experiment.id = experiment_id
-        experiment.mode = 'resume'
+        experiment = Experiment._resume(experiment_id)
         experiment.run(port=port, wait_completion=wait_completion, debug=debug)
         if not wait_completion:
             return experiment

-    @classmethod
-    def view(cls, experiment_id: str, port: int = 8080, non_blocking: bool = False):
+    @staticmethod
+    def view(experiment_id: str, port: int = 8080, non_blocking: bool = False):
         """
         View a stopped experiment.
 ...
@@ -234,11 +241,8 @@ class Experiment:
         non_blocking
             If false, run in the foreground. If true, run in the background.
         """
-        debug = False
-        experiment = Experiment()
-        experiment.id = experiment_id
-        experiment.mode = 'view'
-        experiment.start(port=port, debug=debug)
+        experiment = Experiment._view(experiment_id)
+        experiment.start(port=port, debug=False)
         if non_blocking:
             return experiment
         else:
 ...
@@ -250,6 +254,22 @@ class Experiment:
         finally:
             experiment.stop()

+    @staticmethod
+    def _resume(exp_id, exp_dir=None):
+        exp = Experiment()
+        exp.id = exp_id
+        exp.mode = 'resume'
+        exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
+        return exp
+
+    @staticmethod
+    def _view(exp_id, exp_dir=None):
+        exp = Experiment()
+        exp.id = exp_id
+        exp.mode = 'view'
+        exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
+        return exp
+
     def get_status(self) -> str:
         """
         Return experiment status as a str.
 ...
@@ -259,7 +279,7 @@ class Experiment:
         str
             Experiment status.
         """
-        resp = rest.get(self.port, '/check-status')
+        resp = rest.get(self.port, '/check-status', self.url_prefix)
         return resp['status']

     def get_trial_job(self, trial_job_id: str):
 ...
@@ -276,7 +296,7 @@ class Experiment:
         TrialJob
             A `TrialJob` instance corresponding to `trial_job_id`.
         """
-        resp = rest.get(self.port, '/trial-jobs/{}'.format(trial_job_id))
+        resp = rest.get(self.port, '/trial-jobs/{}'.format(trial_job_id), self.url_prefix)
         return TrialJob(**resp)

     def list_trial_jobs(self):
 ...
@@ -288,7 +308,7 @@ class Experiment:
         list
             List of `TrialJob`.
         """
-        resp = rest.get(self.port, '/trial-jobs')
+        resp = rest.get(self.port, '/trial-jobs', self.url_prefix)
         return [TrialJob(**trial_job) for trial_job in resp]

     def get_job_statistics(self):
 ...
@@ -300,7 +320,7 @@ class Experiment:
         dict
             Job statistics information.
         """
-        resp = rest.get(self.port, '/job-statistics')
+        resp = rest.get(self.port, '/job-statistics', self.url_prefix)
         return resp

     def get_job_metrics(self, trial_job_id=None):
 ...
@@ -318,7 +338,7 @@ class Experiment:
             Each key is a trialJobId, the corresponding value is a list of `TrialMetricData`.
         """
         api = '/metric-data/{}'.format(trial_job_id) if trial_job_id else '/metric-data'
-        resp = rest.get(self.port, api)
+        resp = rest.get(self.port, api, self.url_prefix)
         metric_dict = {}
         for metric in resp:
             trial_id = metric["trialJobId"]
 ...
@@ -337,7 +357,7 @@ class Experiment:
         dict
             The profile of the experiment.
         """
-        resp = rest.get(self.port, '/experiment')
+        resp = rest.get(self.port, '/experiment', self.url_prefix)
         return resp

     def get_experiment_metadata(self, exp_id: str):
 ...
@@ -364,7 +384,7 @@ class Experiment:
         list
             The experiments metadata.
         """
-        resp = rest.get(self.port, '/experiments-info')
+        resp = rest.get(self.port, '/experiments-info', self.url_prefix)
         return resp

     def export_data(self):
 ...
@@ -376,7 +396,7 @@ class Experiment:
         list
             List of `TrialResult`.
         """
-        resp = rest.get(self.port, '/export-data')
+        resp = rest.get(self.port, '/export-data', self.url_prefix)
         return [TrialResult(**trial_result) for trial_result in resp]

     def _get_query_type(self, key: str):
 ...
@@ -403,7 +423,7 @@ class Experiment:
         api = '/experiment{}'.format(self._get_query_type(key))
         experiment_profile = self.get_experiment_profile()
         experiment_profile['params'][key] = value
-        rest.put(self.port, api, experiment_profile)
+        rest.put(self.port, api, experiment_profile, self.url_prefix)
         logging.info('Successfully update %s.', key)

     def update_trial_concurrency(self, value: int):
 ...
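A rough usage sketch of the new RunMode parameter and the staticmethod resume/view introduced above (not part of this commit; the import path for RunMode, the experiment id, and the remaining experiment setup are assumptions):

# Hedged sketch of RunMode and the reworked resume/view entry points.
from nni.experiment import Experiment
from nni.experiment.experiment import RunMode

exp = Experiment('local')            # constructor still accepts a platform string
# exp.start(port=8080, run_mode=RunMode.Detach)   # keep NNI manager alive after this script exits

# resume/view are now staticmethods that take the experiment id directly:
# Experiment.resume('GvQMt3Jz', port=8080)
# Experiment.view('GvQMt3Jz', port=8081, non_blocking=True)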
nni/experiment/launcher.py

@@ -2,7 +2,10 @@
 # Licensed under the MIT license.

 import contextlib
+from dataclasses import dataclass, fields
+from datetime import datetime
 import logging
+import os.path
 from pathlib import Path
 import socket
 from subprocess import Popen
 ...
@@ -23,29 +26,89 @@ from ..tools.nnictl.nnictl_utils import update_experiment
 _logger = logging.getLogger('nni.experiment')

+@dataclass(init=False)
+class NniManagerArgs:
+    port: int
+    experiment_id: int
+    start_mode: str  # new or resume
+    mode: str  # training service platform
+    log_dir: str
+    log_level: str
+    readonly: bool = False
+    foreground: bool = False
+    url_prefix: Optional[str] = None
+    dispatcher_pipe: Optional[str] = None
+
+    def __init__(self, action, exp_id, config, port, debug, foreground, url_prefix):
+        self.port = port
+        self.experiment_id = exp_id
+        self.foreground = foreground
+        self.url_prefix = url_prefix
+        self.log_dir = config.experiment_working_directory
+        if isinstance(config.training_service, list):
+            self.mode = 'hybrid'
+        else:
+            self.mode = config.training_service.platform
+        self.log_level = config.log_level
+        if debug and self.log_level not in ['debug', 'trace']:
+            self.log_level = 'debug'
+        if action == 'resume':
+            self.start_mode = 'resume'
+        elif action == 'view':
+            self.start_mode = 'resume'
+            self.readonly = True
+        else:
+            self.start_mode = 'new'
+
+    def to_command_line_args(self):
+        ret = []
+        for field in fields(self):
+            value = getattr(self, field.name)
+            if value is not None:
+                ret.append('--' + field.name)
+                if isinstance(value, bool):
+                    ret.append(str(value).lower())
+                else:
+                    ret.append(str(value))
+        return ret
+
-def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bool, mode: str = 'new') -> Popen:
-    proc = None
-
-    config.validate(initialized_tuner=False)
-    _ensure_port_idle(port)
-    if mode != 'view':
-        if isinstance(config.training_service, list): # hybrid training service
-            _ensure_port_idle(port + 1, 'Hybrid training service requires an additional port')
-        elif config.training_service.platform in ['remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']:
-            _ensure_port_idle(port + 1, f'{config.training_service.platform} requires an additional port')
+def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
+    foreground = run_mode.value == 'foreground'
+    nni_manager_args = NniManagerArgs(action, exp_id, config, port, debug, foreground, url_prefix)
+
+    _ensure_port_idle(port)
+    websocket_platforms = ['hybrid', 'remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']
+    if action != 'view' and nni_manager_args.mode in websocket_platforms:
+        _ensure_port_idle(port + 1, f'{nni_manager_args.mode} requires an additional port')

+    proc = None
     try:
         _logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
-        start_time, proc = _start_rest_server(config, port, debug, exp_id, mode=mode)
+        proc = _start_rest_server(nni_manager_args, run_mode)
+        start_time = int(time.time() * 1000)
         _logger.info('Starting web server...')
-        _check_rest_server(port)
-        platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
-        _save_experiment_information(exp_id, port, start_time, platform,
-                                     config.experiment_name, proc.pid, str(config.experiment_working_directory), [])
+        _check_rest_server(port, url_prefix=url_prefix)
+        Experiments().add_experiment(
+            exp_id, port, start_time, nni_manager_args.mode,
+            config.experiment_name, pid=proc.pid, logDir=config.experiment_working_directory, tag=[],
+        )
         _logger.info('Setting up...')
-        rest.post(port, '/experiment', config.json())
+        rest.post(port, '/experiment', config.json(), url_prefix)
         return proc
     except Exception as e:
 ...
@@ -55,6 +118,33 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
         proc.kill()
         raise e

+def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
+    node_dir = Path(nni_node.__path__[0])
+    node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
+    main_js = str(node_dir / 'main.js')
+    cmd = [node, '--max-old-space-size=4096', main_js]
+    cmd += nni_manager_args.to_command_line_args()
+
+    if run_mode.value == 'detach':
+        log = Path(nni_manager_args.log_dir, nni_manager_args.experiment_id, 'log')
+        out = (log / 'nnictl_stdout.log').open('a')
+        err = (log / 'nnictl_stderr.log').open('a')
+        header = f'Experiment {nni_manager_args.experiment_id} start: {datetime.now()}'
+        header = '-' * 80 + '\n' + header + '\n' + '-' * 80 + '\n'
+        out.write(header)
+        err.write(header)
+    else:
+        out = None
+        err = None
+
+    if sys.platform == 'win32':
+        from subprocess import CREATE_NEW_PROCESS_GROUP
+        return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
+    else:
+        return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp)
+
 def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen:
     pipe = None
     proc = None
 ...
@@ -69,7 +159,7 @@ def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int,
     try:
         _logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
         pipe = Pipe(exp_id)
-        start_time, proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
+        start_time, proc = _start_rest_server_retiarii(config, port, debug, exp_id, pipe.path)
         _logger.info('Connecting IPC pipe...')
         pipe_file = pipe.connect()
         nni.runtime.protocol._in_file = pipe_file
 ...
@@ -101,8 +191,8 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
     raise RuntimeError(f'Port {port} is not idle {message}')

-def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str = None,
-                       mode: str = 'new') -> Tuple[int, Popen]:
+def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool, experiment_id: str,
+                                pipe_path: str = None, mode: str = 'new') -> Tuple[int, Popen]:
     if isinstance(config.training_service, list):
         ts = 'hybrid'
     else:
 ...
@@ -145,15 +235,15 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
     return int(time.time() * 1000), proc

-def _check_rest_server(port: int, retry: int = 3) -> None:
+def _check_rest_server(port: int, retry: int = 3, url_prefix: Optional[str] = None) -> None:
     for i in range(retry):
         with contextlib.suppress(Exception):
-            rest.get(port, '/check-status')
+            rest.get(port, '/check-status', url_prefix)
             return
         if i > 0:
             _logger.warning('Timeout, retry...')
         time.sleep(1)
-    rest.get(port, '/check-status')
+    rest.get(port, '/check-status', url_prefix)

 def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str,
 ...
@@ -162,7 +252,16 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,
     experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir, tag=tag)

-def get_stopped_experiment_config(exp_id: str, mode: str) -> None:
+def get_stopped_experiment_config(exp_id, exp_dir=None):
+    if exp_dir:
+        exp_config = Config(exp_id, exp_dir).get_config()
+        config = ExperimentConfig(**exp_config)
+        if not os.path.samefile(exp_dir, config.experiment_working_directory):
+            msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
+            _logger.warning(msg, exp_dir, config.experiment_working_directory)
+            config.experiment_working_directory = exp_dir
+        return config
+    else:
     update_experiment()
     experiments_config = Experiments()
     experiments_dict = experiments_config.get_all_experiments()
 ...
@@ -171,8 +270,7 @@ def get_stopped_experiment_config(exp_id: str, mode: str) -> None:
         _logger.error('Id %s not exist!', exp_id)
         return
     if experiment_metadata['status'] != 'STOPPED':
-        _logger.error('Only stopped experiments can be %sed!', mode)
+        _logger.error('Only stopped experiments can be resumed or viewed!')
         return
     experiment_config = Config(exp_id, experiment_metadata['logDir']).get_config()
-    config = ExperimentConfig(**experiment_config)
-    return config
+    return ExperimentConfig(**experiment_config)
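A standalone illustration of the flag-flattening rule used by NniManagerArgs.to_command_line_args() above (a hedged sketch, not part of this commit; the dataclass, its field names, and the values are illustrative only):

# Hedged sketch: None fields are skipped, booleans are lower-cased, everything else is str().
from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class _Args:
    port: int = 8080
    start_mode: str = 'new'
    readonly: bool = False
    url_prefix: Optional[str] = None   # None fields produce no flag at all

def to_command_line_args(args):
    ret = []
    for field in fields(args):
        value = getattr(args, field.name)
        if value is not None:
            ret.append('--' + field.name)
            ret.append(str(value).lower() if isinstance(value, bool) else str(value))
    return ret

print(to_command_line_args(_Args()))
# ['--port', '8080', '--start_mode', 'new', '--readonly', 'false']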
nni/experiment/rest.py

@@ -5,31 +5,40 @@ import requests
 _logger = logging.getLogger(__name__)

-url_template = 'http://localhost:{}/api/v1/nni{}'
 timeout = 20

-def request(method: str, port: Optional[int], api: str, data: Any = None) -> Any:
+def request(method: str, port: Optional[int], api: str, data: Any = None, prefix: Optional[str] = None) -> Any:
     if port is None:
         raise RuntimeError('Experiment is not running')
-    url = url_template.format(port, api)
+    url_parts = [f'http://localhost:{port}', prefix, 'api/v1/nni', api]
+    url = '/'.join(part.strip('/') for part in url_parts if part)
     if data is None:
         resp = requests.request(method, url, timeout=timeout)
     else:
         resp = requests.request(method, url, json=data, timeout=timeout)
     if not resp.ok:
         _logger.error('rest request %s %s failed: %s %s', method.upper(), url, resp.status_code, resp.text)
     resp.raise_for_status()
     if method.lower() in ['get', 'post'] and len(resp.content) > 0:
         return resp.json()

-def get(port: Optional[int], api: str) -> Any:
-    return request('get', port, api)
+def get(port: Optional[int], api: str, prefix: Optional[str] = None) -> Any:
+    return request('get', port, api, prefix=prefix)

-def post(port: Optional[int], api: str, data: Any) -> Any:
-    return request('post', port, api, data)
+def post(port: Optional[int], api: str, data: Any, prefix: Optional[str] = None) -> Any:
+    return request('post', port, api, data, prefix=prefix)

-def put(port: Optional[int], api: str, data: Any) -> None:
-    request('put', port, api, data)
+def put(port: Optional[int], api: str, data: Any, prefix: Optional[str] = None) -> None:
+    request('put', port, api, data, prefix=prefix)

-def delete(port: Optional[int], api: str) -> None:
-    request('delete', port, api)
+def delete(port: Optional[int], api: str, prefix: Optional[str] = None) -> None:
+    request('delete', port, api, prefix=prefix)
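A small standalone example of the new URL assembly (a hedged sketch reproducing the logic above; port, prefix, and api values are illustrative):

# Hedged sketch of the prefix-aware URL join.
port, prefix, api = 8080, 'myprefix/', '/check-status'
url_parts = [f'http://localhost:{port}', prefix, 'api/v1/nni', api]
url = '/'.join(part.strip('/') for part in url_parts if part)
print(url)
# http://localhost:8080/myprefix/api/v1/nni/check-status
# With prefix=None the falsy part is skipped, so the old URL shape is preserved:
# http://localhost:8080/api/v1/nni/check-status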
nni/retiarii/experiment/pytorch.py

@@ -18,9 +18,10 @@ import torch
 import torch.nn as nn

 import nni.runtime.log
 from nni.common.device import GPUDevice
-from nni.experiment import Experiment, TrainingServiceConfig, launcher, management, rest
-from nni.experiment.config import util
-from nni.experiment.config.base import ConfigBase, PathLike
+from nni.experiment import Experiment, launcher, management, rest
+from nni.experiment.config import utils
+from nni.experiment.config.base import ConfigBase
+from nni.experiment.config.training_service import TrainingServiceConfig
 from nni.experiment.pipe import Pipe
 from nni.tools.nnictl.command_utils import kill_command
 ...
@@ -45,7 +46,7 @@ class RetiariiExeConfig(ConfigBase):
     experiment_name: Optional[str] = None
     search_space: Any = ''  # TODO: remove
     trial_command: str = '_reserved'
-    trial_code_directory: PathLike = '.'
+    trial_code_directory: utils.PathLike = '.'
     trial_concurrency: int
     trial_gpu_number: int = 0
     devices: Optional[List[Union[str, GPUDevice]]] = None
 ...
@@ -56,7 +57,7 @@ class RetiariiExeConfig(ConfigBase):
     nni_manager_ip: Optional[str] = None
     debug: bool = False
     log_level: Optional[str] = None
-    experiment_working_directory: PathLike = '~/nni-experiments'
+    experiment_working_directory: utils.PathLike = '~/nni-experiments'
     # remove configuration of tuner/assessor/advisor
     training_service: TrainingServiceConfig
     execution_engine: str = 'py'
 ...
@@ -71,7 +72,7 @@ class RetiariiExeConfig(ConfigBase):
         super().__init__(**kwargs)
         if training_service_platform is not None:
             assert 'training_service' not in kwargs
-            self.training_service = util.training_service_config_factory(platform=training_service_platform)
+            self.training_service = utils.training_service_config_factory(platform=training_service_platform)
         self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry py'

     def __setattr__(self, key, value):
 ...
@@ -100,16 +101,12 @@ class RetiariiExeConfig(ConfigBase):
     _canonical_rules = {
-        'trial_code_directory': util.canonical_path,
-        'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
-        'experiment_working_directory': util.canonical_path
     }

     _validation_rules = {
         'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
         'trial_concurrency': lambda value: value > 0,
         'trial_gpu_number': lambda value: value >= 0,
-        'max_experiment_duration': lambda value: util.parse_time(value) > 0,
         'max_trial_number': lambda value: value > 0,
         'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
         'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
 ...
nni/runtime/log.py

@@ -66,7 +66,9 @@ def start_experiment_log(experiment_id: str, log_directory: Path, debug: bool) -
 def stop_experiment_log(experiment_id: str) -> None:
     if experiment_id in handlers:
-        logging.getLogger().removeHandler(handlers.pop(experiment_id))
+        handler = handlers.pop(experiment_id, None)
+        if handler is not None:
+            logging.getLogger().removeHandler(handler)

 def _init_logger_dispatcher() -> None:
 ...
nni/tools/nnictl/launcher.py (diff collapsed in this view)
nni/tools/nnictl/legacy_launcher.py (new file, 0 → 100644; diff collapsed in this view)
nni/tools/nnictl/nnictl.py

@@ -62,7 +62,7 @@ def parse_args():
     # parse resume command
     parser_resume = subparsers.add_parser('resume', help='resume a new experiment')
-    parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to resume')
+    parser_resume.add_argument('id', help='The id of the experiment you want to resume')
     parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', type=int, help='the port of restful server')
     parser_resume.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
     parser_resume.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
 ...
@@ -72,7 +72,7 @@ def parse_args():
     # parse view command
     parser_view = subparsers.add_parser('view', help='view a stopped experiment')
-    parser_view.add_argument('id', nargs='?', help='The id of the experiment you want to view')
+    parser_view.add_argument('id', help='The id of the experiment you want to view')
     parser_view.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', type=int, help='the port of restful server')
     parser_view.add_argument('--experiment_dir', '-e', help='view experiment from external folder, specify the full path of ' \
                              'experiment folder')
 ...
test/config/integration_tests.yml

@@ -199,9 +199,6 @@ testCases:
   launchCommand: nnictl view $resumeExpId
   experimentStatusCheck: False

-- name: multi-thread
-  configFile: test/config/multi_thread/config.yml
-
 #########################################################################
 # nni assessor test
 ...
test/config/integration_tests_tf2.yml

@@ -132,9 +132,6 @@ testCases:
   launchCommand: nnictl view $resumeExpId
   experimentStatusCheck: False

-- name: multi-thread
-  configFile: test/config/multi_thread/config.yml
-
 #########################################################################
 # nni assessor test
 #########################################################################
 ...
test/config/pr_tests.yml

@@ -42,9 +42,6 @@ testCases:
   kwargs:
     expected_result_file: expected_metrics_dict.json

-- name: multi-thread
-  configFile: test/config/multi_thread/config.yml
-
 #########################################################################
 # nni assessor test
 #########################################################################
 ...
test/ut/experiment/assets/config.yaml (new file, 0 → 100644)

experimentName: test case
searchSpaceFile: search_space.json
trialCommand: python main.py
trialCodeDirectory: ../assets
trialConcurrency: 2
trialGpuNumber: 1
maxExperimentDuration: 1.5h
maxTrialNumber: 10
maxTrialDuration: 60
nniManagerIp: 1.2.3.4
debug: true
logLevel: warning
tunerGpuIndices: 0
assessor:
  name: assess
advisor:
  className: Advisor
  codeDirectory: .
  classArgs: {random_seed: 0}
trainingService:
  platform: local
  useActiveGpu: false
  maxTrialNumberPerGpu: 2
  gpuIndices: 1,2
  reuseMode: true
sharedStorage:
  storageType: NFS
  localMountPoint: .  # git cannot commit empty dir, so just use this
  remoteMountPoint: /tmp
  localMounted: usermount
  nfsServer: nfs.test.case
  exportedDirectory: root