Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d5857823
"...include/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "afa4a8333e1fee8ec57504bb5cf5f1618a805809"
Unverified
Commit
d5857823
authored
Dec 20, 2021
by
liuzhe-lz
Committed by
GitHub
Dec 20, 2021
Browse files
Config refactor (#4370)
parent
cb090e8c
Changes
70
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
826 additions
and
622 deletions
+826
-622
dependencies/required.txt
dependencies/required.txt
+1
-0
docs/en_US/Tutorial/HowToLaunchFromPython.rst
docs/en_US/Tutorial/HowToLaunchFromPython.rst
+0
-2
nni/__main__.py
nni/__main__.py
+1
-2
nni/experiment/__init__.py
nni/experiment/__init__.py
+1
-1
nni/experiment/config/__init__.py
nni/experiment/config/__init__.py
+3
-9
nni/experiment/config/adl.py
nni/experiment/config/adl.py
+0
-17
nni/experiment/config/algorithm.py
nni/experiment/config/algorithm.py
+68
-0
nni/experiment/config/base.py
nni/experiment/config/base.py
+234
-120
nni/experiment/config/common.py
nni/experiment/config/common.py
+0
-208
nni/experiment/config/convert.py
nni/experiment/config/convert.py
+170
-159
nni/experiment/config/exp_config.py
nni/experiment/config/exp_config.py
+154
-0
nni/experiment/config/local.py
nni/experiment/config/local.py
+0
-28
nni/experiment/config/remote.py
nni/experiment/config/remote.py
+0
-63
nni/experiment/config/shared_storage.py
nni/experiment/config/shared_storage.py
+14
-2
nni/experiment/config/training_service.py
nni/experiment/config/training_service.py
+58
-0
nni/experiment/config/training_services/__init__.py
nni/experiment/config/training_services/__init__.py
+11
-0
nni/experiment/config/training_services/aml.py
nni/experiment/config/training_services/aml.py
+15
-6
nni/experiment/config/training_services/dlc.py
nni/experiment/config/training_services/dlc.py
+1
-5
nni/experiment/config/training_services/frameworkcontroller.py
...xperiment/config/training_services/frameworkcontroller.py
+48
-0
nni/experiment/config/training_services/k8s_storage.py
nni/experiment/config/training_services/k8s_storage.py
+47
-0
No files found.
dependencies/required.txt
View file @
d5857823
...
@@ -6,6 +6,7 @@ pyyaml >= 5.4
...
@@ -6,6 +6,7 @@ pyyaml >= 5.4
requests
requests
responses
responses
schema
schema
typeguard
PythonWebHDFS
PythonWebHDFS
colorama
colorama
scikit-learn >= 0.24.1 ; python_version >= "3.7"
scikit-learn >= 0.24.1 ; python_version >= "3.7"
...
...
docs/en_US/Tutorial/HowToLaunchFromPython.rst
View file @
d5857823
...
@@ -314,6 +314,4 @@ Azure Blob Config
...
@@ -314,6 +314,4 @@ Azure Blob Config
.. autoattribute:: nni.experiment.config.AzureBlobConfig.storage_account_key
.. autoattribute:: nni.experiment.config.AzureBlobConfig.storage_account_key
.. autoattribute:: nni.experiment.config.AzureBlobConfig.resource_group_name
.. autoattribute:: nni.experiment.config.AzureBlobConfig.container_name
.. autoattribute:: nni.experiment.config.AzureBlobConfig.container_name
nni/__main__.py
View file @
d5857823
...
@@ -33,11 +33,10 @@ def main():
...
@@ -33,11 +33,10 @@ def main():
enable_multi_thread
()
enable_multi_thread
()
if
'trainingServicePlatform'
in
exp_params
:
# config schema is v1
if
'trainingServicePlatform'
in
exp_params
:
# config schema is v1
from
types
import
SimpleNamespace
from
.experiment.config.convert
import
convert_algo
from
.experiment.config.convert
import
convert_algo
for
algo_type
in
[
'tuner'
,
'assessor'
,
'advisor'
]:
for
algo_type
in
[
'tuner'
,
'assessor'
,
'advisor'
]:
if
algo_type
in
exp_params
:
if
algo_type
in
exp_params
:
exp_params
[
algo_type
]
=
convert_algo
(
algo_type
,
exp_params
,
SimpleNamespace
()).
json
(
)
exp_params
[
algo_type
]
=
convert_algo
(
algo_type
,
exp_params
[
algo_type
]
)
if
exp_params
.
get
(
'advisor'
)
is
not
None
:
if
exp_params
.
get
(
'advisor'
)
is
not
None
:
# advisor is enabled and starts to run
# advisor is enabled and starts to run
...
...
nni/experiment/__init__.py
View file @
d5857823
...
@@ -2,5 +2,5 @@
...
@@ -2,5 +2,5 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
from
.config
import
*
from
.config
import
*
from
.experiment
import
Experiment
from
.experiment
import
Experiment
,
RunMode
from
.data
import
*
from
.data
import
*
nni/experiment/config/__init__.py
View file @
d5857823
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
.common
import
*
from
.exp_config
import
ExperimentConfig
from
.local
import
*
from
.algorithm
import
AlgorithmConfig
,
CustomAlgorithmConfig
from
.remote
import
*
from
.training_services
import
*
from
.openpai
import
*
from
.aml
import
*
from
.kubeflow
import
*
from
.frameworkcontroller
import
*
from
.adl
import
*
from
.dlc
import
*
from
.shared_storage
import
*
from
.shared_storage
import
*
nni/experiment/config/adl.py
deleted
100644 → 0
View file @
cb090e8c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
from
.common
import
TrainingServiceConfig
__all__
=
[
'AdlConfig'
]
@
dataclass
(
init
=
False
)
class
AdlConfig
(
TrainingServiceConfig
):
platform
:
str
=
'adl'
docker_image
:
str
=
'msranni/nni:latest'
_validation_rules
=
{
'platform'
:
lambda
value
:
(
value
==
'adl'
,
'cannot be modified'
)
}
nni/experiment/config/algorithm.py
0 → 100644
View file @
d5857823
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Config classes for tuner/assessor/advisor algorithms.
Use ``AlgorithmConfig`` to specify a built-in algorithm;
use ``CustomAlgorithmConfig`` to specify a custom algorithm.
Check the reference_ for explaination of each field.
You may also want to check `tuner's overview`_.
.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
.. _tuner's overview: https://nni.readthedocs.io/en/stable/Tuner/BuiltinTuner.html
"""
__all__
=
[
'AlgorithmConfig'
,
'CustomAlgorithmConfig'
]
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
from
.base
import
ConfigBase
from
.utils
import
PathLike
@
dataclass
(
init
=
False
)
class
_AlgorithmConfig
(
ConfigBase
):
"""
Common base class for ``AlgorithmConfig`` and ``CustomAlgorithmConfig``.
It's a "union set" of 2 derived classes. So users can use it as either one.
"""
name
:
Optional
[
str
]
=
None
class_name
:
Optional
[
str
]
=
None
code_directory
:
Optional
[
PathLike
]
=
None
class_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
def
_validate_canonical
(
self
):
super
().
_validate_canonical
()
if
self
.
class_name
is
None
:
# assume it's built-in algorithm by default
assert
self
.
name
assert
self
.
code_directory
is
None
else
:
# custom algorithm
assert
self
.
name
is
None
assert
self
.
class_name
if
not
Path
(
self
.
code_directory
).
is_dir
():
raise
ValueError
(
f
'CustomAlgorithmConfig: code_directory "
{
self
.
code_directory
}
" is not a directory'
)
@
dataclass
(
init
=
False
)
class
AlgorithmConfig
(
_AlgorithmConfig
):
"""
Configuration for built-in algorithm.
"""
name
:
str
class_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
@
dataclass
(
init
=
False
)
class
CustomAlgorithmConfig
(
_AlgorithmConfig
):
"""
Configuration for custom algorithm.
"""
class_name
:
str
code_directory
:
Optional
[
PathLike
]
=
'.'
class_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
nni/experiment/config/base.py
View file @
d5857823
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
"""
``ConfigBase`` class. Nothing else.
Docstrings in this file are mainly for NNI contributors instead of end users.
"""
__all__
=
[
'ConfigBase'
]
import
copy
import
copy
import
dataclasses
import
dataclasses
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
,
Type
,
TypeVar
import
yaml
import
yaml
from
.
import
util
from
.
import
util
s
__all__
=
[
'ConfigBase'
,
'PathLike'
]
class
ConfigBase
:
"""
The abstract base class of experiment config classes.
T
=
TypeVar
(
'T'
,
bound
=
'ConfigBase'
)
A config class should be a type-hinted dataclass inheriting ``ConfigBase``.
Or for a training service config class, it can inherit ``TrainingServiceConfig``.
PathLike
=
util
.
PathLike
.. code-block:: python
def
_is_missing
(
obj
:
Any
)
->
bool
:
@dataclass(init=False)
return
isinstance
(
obj
,
type
(
dataclasses
.
MISSING
))
class ExperimentConfig(ConfigBase):
name: Optional[str]
...
class
ConfigBase
:
Subclasses are suggested to override ``_canonicalize()`` and ``_validate_canonical()`` methods.
"""
Base class of config classes.
Users can create a config object with constructor or ``ConfigBase.load()``,
Subclass may override `_canonical_rules` and `_validation_rules`,
validate its legality with ``ConfigBase.validate()``,
and `validate()` if the logic is complex.
and finally convert it to the format accepted by NNI manager with ``ConfigBase.json()``.
Example usage:
.. code-block:: python
# when using Python API
config1 = ExperimentConfig(trialCommand='...', trialConcurrency=1, ...)
config1.validate()
print(config1.json())
# when using config file
config2 = ExperimentConfig.load('examples/config.yml')
config2.validate()
print(config2.json())
Config objects will remember where they are loaded; therefore relative paths can be resolved smartly.
If a config object is created with constructor, the base path will be current working directory.
If it is loaded with ``ConfigBase.load(path)``, the base path will be ``path``'s parent.
"""
"""
# Rules to convert field value to canonical format.
def
__init__
(
self
,
**
kwargs
):
# The key is field name.
# The value is callable `value -> canonical_value`
# It is not type-hinted so dataclass won't treat it as field
_canonical_rules
=
{}
# type: ignore
# Rules to validate field value.
# The key is field name.
# The value is callable `value -> valid` or `value -> (valid, error_message)`
# The rule will be called with canonical format and is only called when `value` is not None.
# `error_message` is used when `valid` is False.
# It will be prepended with class name and field name in exception message.
_validation_rules
=
{}
# type: ignore
def
__init__
(
self
,
*
,
_base_path
:
Optional
[
Path
]
=
None
,
**
kwargs
):
"""
"""
Initialize a config object and set some fields.
There are two common ways to use the constructor,
Name of keyword arguments can either be snake_case or camelCase.
directly writing Python code and unpacking from JSON(YAML) object:
They will be converted to snake_case automatically.
If a field is missing and don't have default value, it will be set to `dataclasses.MISSING`.
.. code-block:: python
config1 = AlgorithmConfig(name='TPE', class_args={'optimize_mode': 'maximize'})
json = {'name': 'TPE', 'classArgs': {'optimize_mode': 'maximize'}}
config2 = AlgorithmConfig(**json)
If the config class has fields whose type is another config class, or list of another config class,
they will recursively load dict values.
Because JSON objects can use "camelCase" for field names,
cases and underscores in ``kwargs`` keys are ignored in this constructor.
For example if a config class has a field ``hello_world``,
then using ``hello_world=1``, ``helloWorld=1``, and ``_HELLOWORLD_=1`` in constructor
will all assign to the same field.
If ``kwargs`` contain extra keys, a `ValueError` will be raised.
If ``kwargs`` do not have enough key, missing fields are silently set to `MISSING()`.
You can use ``utils.is_missing()`` to check them.
"""
"""
if
'basepath'
in
kwargs
:
self
.
_base_path
=
utils
.
get_base_path
()
_base_path
=
kwargs
.
pop
(
'basepath'
)
args
=
{
utils
.
case_insensitive
(
key
):
value
for
key
,
value
in
kwargs
.
items
()}
kwargs
=
{
util
.
case_insensitive
(
key
):
value
for
key
,
value
in
kwargs
.
items
()}
if
_base_path
is
None
:
_base_path
=
Path
()
for
field
in
dataclasses
.
fields
(
self
):
for
field
in
dataclasses
.
fields
(
self
):
value
=
kwargs
.
pop
(
util
.
case_insensitive
(
field
.
name
),
field
.
default
)
value
=
args
.
pop
(
utils
.
case_insensitive
(
field
.
name
),
field
.
default
)
if
value
is
not
None
and
not
_is_missing
(
value
):
# relative paths loaded from config file are not relative to pwd
if
'Path'
in
str
(
field
.
type
):
value
=
Path
(
value
).
expanduser
()
if
not
value
.
is_absolute
():
value
=
_base_path
/
value
setattr
(
self
,
field
.
name
,
value
)
setattr
(
self
,
field
.
name
,
value
)
if
kwargs
:
if
args
:
# maybe a key is misspelled
cls
=
type
(
self
).
__name__
class_name
=
type
(
self
).
__name__
fields
=
', '
.
join
(
kwargs
.
keys
())
fields
=
', '
.
join
(
args
.
keys
())
raise
ValueError
(
f
'
{
cls
}
: Unrecognized fields
{
fields
}
'
)
raise
ValueError
(
f
'
{
class_name
}
does not have field(s)
{
fields
}
'
)
# try to unpack nested config
for
field
in
dataclasses
.
fields
(
self
):
value
=
getattr
(
self
,
field
.
name
)
if
utils
.
is_instance
(
value
,
field
.
type
):
continue
# already accepted by subclass, don't touch it
if
isinstance
(
value
,
dict
):
config
=
utils
.
guess_config_type
(
value
,
field
.
type
)
if
config
is
not
None
:
setattr
(
self
,
field
.
name
,
config
)
elif
isinstance
(
value
,
list
)
and
value
and
isinstance
(
value
[
0
],
dict
):
configs
=
utils
.
guess_list_config_type
(
value
,
field
.
type
)
if
configs
:
setattr
(
self
,
field
.
name
,
configs
)
@
classmethod
@
classmethod
def
load
(
cls
:
Type
[
T
],
path
:
PathLike
)
->
T
:
def
load
(
cls
,
path
)
:
"""
"""
Load config from YAML (or JSON) file.
Load a YAML config file from file system.
Keys in YAML file can either be camelCase or snake_case.
Since YAML is a superset of JSON, it can also load JSON files.
This method raises exception if:
- The file is not available
- The file content is not valid YAML
- Top level value of the YAML is not object
- The YAML contains not supported fields
It does not raise exception when the YAML misses fields or contains bad fields.
Parameters
----------
path : PathLike
Path of the config file.
Returns
-------
cls
An object of ConfigBase subclass.
"""
"""
data
=
yaml
.
safe_load
(
open
(
path
))
with
open
(
path
)
as
yaml_file
:
data
=
yaml
.
safe_load
(
yaml_file
)
if
not
isinstance
(
data
,
dict
):
if
not
isinstance
(
data
,
dict
):
raise
ValueError
(
f
'Content of config file
{
path
}
is not a dict/object'
)
raise
ValueError
(
f
'Conent of config file
{
path
}
is not a dict/object'
)
return
cls
(
**
data
,
_base_path
=
Path
(
path
).
parent
)
utils
.
set_base_path
(
Path
(
path
).
parent
)
config
=
cls
(
**
data
)
utils
.
unset_base_path
()
return
config
def
json
(
self
)
->
Dict
[
str
,
Any
]
:
def
canonical_copy
(
self
)
:
"""
"""
Convert config to JSON object.
Create a canonicalized copy of the config, and validate it.
The keys of returned object will be camelCase.
This function is mainly used internally by NNI.
Returns
-------
type(self)
A deep copy.
"""
canon
=
copy
.
deepcopy
(
self
)
canon
.
_canonicalize
([])
canon
.
_validate_canonical
()
return
canon
def
validate
(
self
):
"""
Validate legality of the config object. Raise exception if any error occurred.
This function does **not** return truth value. Do not write ``if config.validate()``.
Returns
-------
None
"""
"""
self
.
validate
()
self
.
canonical_copy
()
return
dataclasses
.
asdict
(
self
.
canonical
(),
dict_factory
=
lambda
items
:
dict
((
util
.
camel_case
(
k
),
v
)
for
k
,
v
in
items
if
v
is
not
None
)
)
def
canonical
(
self
:
T
)
->
T
:
def
json
(
self
)
:
"""
"""
Returns a deep copy, where the fields supporting multiple formats are converted to the canonical format.
Convert the config to JSON object (not JSON string).
Noticeably, relative path may be converted to absolute path.
In current implementation ``json()`` will invoke ``validate()``, but this might change in future version.
It is recommended to call ``validate()`` before ``json()`` for now.
Returns
-------
dict
JSON object.
"""
"""
ret
=
copy
.
deepcopy
(
self
)
canon
=
self
.
canonical_copy
()
for
field
in
dataclasses
.
fields
(
ret
):
return
dataclasses
.
asdict
(
canon
,
dict_factory
=
_dict_factory
)
# this is recursive
key
,
value
=
field
.
name
,
getattr
(
ret
,
field
.
name
)
rule
=
ret
.
_canonical_rules
.
get
(
key
)
def
_canonicalize
(
self
,
parents
):
if
rule
is
not
None
:
setattr
(
ret
,
key
,
rule
(
value
))
elif
isinstance
(
value
,
ConfigBase
):
setattr
(
ret
,
key
,
value
.
canonical
())
# value will be copied twice, should not be a performance issue anyway
elif
isinstance
(
value
,
Path
):
setattr
(
ret
,
key
,
str
(
value
))
return
ret
def
validate
(
self
)
->
None
:
"""
"""
Validate the config object and raise Exception if it's ill-formed.
The config schema for end users is more flexible than the format NNI manager accepts.
This method convert a config object to the constrained format accepted by NNI manager.
The default implementation will:
1. Resolve all ``PathLike`` fields to absolute path
2. Call ``_canonicalize()`` on all children config objects, including those inside list and dict
Subclasses are recommended to call ``super()._canonicalize(parents)`` at the end of their overrided version.
Parameters
----------
parents : list[ConfigBase]
The upper level config objects.
For example local training service's ``trialGpuNumber`` will be copied from top level when not set,
in this case it will be invoked like ``localConfig._canonicalize([experimentConfig])``.
"""
"""
class_name
=
type
(
self
).
__name__
for
field
in
dataclasses
.
fields
(
self
):
config
=
self
.
canonical
()
value
=
getattr
(
self
,
field
.
name
)
if
isinstance
(
value
,
(
Path
,
str
))
and
utils
.
is_path_like
(
field
.
type
):
for
field
in
dataclasses
.
fields
(
config
):
setattr
(
self
,
field
.
name
,
utils
.
resolve_path
(
value
,
self
.
_base_path
))
key
,
value
=
field
.
name
,
getattr
(
config
,
field
.
name
)
else
:
_recursive_canonicalize_child
(
value
,
[
self
]
+
parents
)
# check existence
if
_is_missing
(
value
):
def
_validate_canonical
(
self
):
raise
ValueError
(
f
'
{
class_name
}
:
{
key
}
is not set'
)
"""
Validate legality of a canonical config object. It's caller's responsibility to ensure the config is canonical.
# check type (TODO)
type_name
=
str
(
field
.
type
).
replace
(
'typing.'
,
''
)
Raise exception if any problem found. This function does **not** return truth value.
optional
=
any
([
type_name
.
startswith
(
'Optional['
),
The default implementation will:
type_name
.
startswith
(
'Union['
)
and
'None'
in
type_name
,
type_name
==
'Any'
1. Validate that all fields match their type hint
])
2. Call ``_validate_canonical()`` on children config objects, including those inside list and dict
if
value
is
None
:
if
optional
:
Subclasses are recommended to to call ``super()._validate_canonical()``.
continue
"""
else
:
utils
.
validate_type
(
self
)
raise
ValueError
(
f
'
{
class_name
}
:
{
key
}
cannot be None'
)
for
field
in
dataclasses
.
fields
(
self
):
value
=
getattr
(
self
,
field
.
name
)
# check value
_recursive_validate_child
(
value
)
rule
=
config
.
_validation_rules
.
get
(
key
)
if
rule
is
not
None
:
def
__setattr__
(
self
,
name
,
value
):
try
:
if
hasattr
(
self
,
name
)
or
name
.
startswith
(
'_'
):
result
=
rule
(
value
)
super
().
__setattr__
(
name
,
value
)
except
Exception
:
return
raise
ValueError
(
f
'
{
class_name
}
:
{
key
}
has bad value
{
repr
(
value
)
}
'
)
if
name
in
[
field
.
name
for
field
in
dataclasses
.
fields
(
self
)]:
# might happend during __init__
super
().
__setattr__
(
name
,
value
)
if
isinstance
(
result
,
bool
):
return
if
not
result
:
raise
AttributeError
(
f
'
{
type
(
self
).
__name__
}
does not have field
{
name
}
'
)
raise
ValueError
(
f
'
{
class_name
}
:
{
key
}
(
{
repr
(
value
)
}
) is out of range'
)
else
:
def
_dict_factory
(
items
):
if
not
result
[
0
]:
ret
=
{}
raise
ValueError
(
f
'
{
class_name
}
:
{
key
}
{
result
[
1
]
}
'
)
for
key
,
value
in
items
:
if
value
is
not
None
:
# check nested config
k
=
utils
.
camel_case
(
key
)
if
isinstance
(
value
,
ConfigBase
):
v
=
str
(
value
)
if
isinstance
(
value
,
Path
)
else
value
value
.
validate
()
ret
[
k
]
=
v
return
ret
def
_recursive_canonicalize_child
(
child
,
parents
):
if
isinstance
(
child
,
ConfigBase
):
child
.
_canonicalize
(
parents
)
elif
isinstance
(
child
,
list
):
for
item
in
child
:
_recursive_canonicalize_child
(
item
,
parents
)
elif
isinstance
(
child
,
dict
):
for
item
in
child
.
values
():
_recursive_canonicalize_child
(
item
,
parents
)
def
_recursive_validate_child
(
child
):
if
isinstance
(
child
,
ConfigBase
):
child
.
_validate_canonical
()
elif
isinstance
(
child
,
list
):
for
item
in
child
:
_recursive_validate_child
(
item
)
elif
isinstance
(
child
,
dict
):
for
item
in
child
.
values
():
_recursive_validate_child
(
item
)
nni/experiment/config/common.py
deleted
100644 → 0
View file @
cb090e8c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
yaml
from
.base
import
ConfigBase
,
PathLike
from
.
import
util
__all__
=
[
'ExperimentConfig'
,
'AlgorithmConfig'
,
'CustomAlgorithmConfig'
,
'TrainingServiceConfig'
,
]
@
dataclass
(
init
=
False
)
class
_AlgorithmConfig
(
ConfigBase
):
name
:
Optional
[
str
]
=
None
class_name
:
Optional
[
str
]
=
None
code_directory
:
Optional
[
PathLike
]
=
None
class_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
def
validate
(
self
):
super
().
validate
()
_validate_algo
(
self
)
_canonical_rules
=
{
'code_directory'
:
util
.
canonical_path
}
@
dataclass
(
init
=
False
)
class
AlgorithmConfig
(
_AlgorithmConfig
):
name
:
str
class_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
@
dataclass
(
init
=
False
)
class
CustomAlgorithmConfig
(
_AlgorithmConfig
):
class_name
:
str
code_directory
:
Optional
[
PathLike
]
=
'.'
class_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
class
TrainingServiceConfig
(
ConfigBase
):
platform
:
str
@
dataclass
(
init
=
False
)
class
SharedStorageConfig
(
ConfigBase
):
storage_type
:
str
local_mount_point
:
PathLike
remote_mount_point
:
str
local_mounted
:
str
storage_account_name
:
Optional
[
str
]
=
None
storage_account_key
:
Optional
[
str
]
=
None
container_name
:
Optional
[
str
]
=
None
nfs_server
:
Optional
[
str
]
=
None
exported_directory
:
Optional
[
str
]
=
None
def
__init__
(
self
,
*
,
_base_path
:
Optional
[
Path
]
=
None
,
**
kwargs
):
kwargs
=
{
util
.
case_insensitive
(
key
):
value
for
key
,
value
in
kwargs
.
items
()}
if
'localmountpoint'
in
kwargs
:
kwargs
[
'localmountpoint'
]
=
Path
(
kwargs
[
'localmountpoint'
]).
expanduser
()
if
not
kwargs
[
'localmountpoint'
].
is_absolute
():
raise
ValueError
(
'localMountPoint can only be set as an absolute path.'
)
super
().
__init__
(
_base_path
=
_base_path
,
**
kwargs
)
@
dataclass
(
init
=
False
)
class
ExperimentConfig
(
ConfigBase
):
experiment_name
:
Optional
[
str
]
=
None
search_space_file
:
Optional
[
PathLike
]
=
None
search_space
:
Any
=
None
trial_command
:
str
trial_code_directory
:
PathLike
=
'.'
trial_concurrency
:
int
trial_gpu_number
:
Optional
[
int
]
=
None
# TODO: in openpai cannot be None
max_experiment_duration
:
Optional
[
str
]
=
None
max_trial_number
:
Optional
[
int
]
=
None
max_trial_duration
:
Optional
[
int
]
=
None
nni_manager_ip
:
Optional
[
str
]
=
None
use_annotation
:
bool
=
False
debug
:
bool
=
False
log_level
:
Optional
[
str
]
=
None
experiment_working_directory
:
PathLike
=
'~/nni-experiments'
tuner_gpu_indices
:
Union
[
List
[
int
],
str
,
int
,
None
]
=
None
tuner
:
Optional
[
_AlgorithmConfig
]
=
None
assessor
:
Optional
[
_AlgorithmConfig
]
=
None
advisor
:
Optional
[
_AlgorithmConfig
]
=
None
training_service
:
Union
[
TrainingServiceConfig
,
List
[
TrainingServiceConfig
]]
shared_storage
:
Optional
[
SharedStorageConfig
]
=
None
_deprecated
:
Optional
[
Dict
[
str
,
Any
]]
=
None
def
__init__
(
self
,
training_service_platform
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
**
kwargs
):
base_path
=
kwargs
.
pop
(
'_base_path'
,
None
)
kwargs
=
util
.
case_insensitive
(
kwargs
)
if
training_service_platform
is
not
None
:
assert
'trainingservice'
not
in
kwargs
kwargs
[
'trainingservice'
]
=
util
.
training_service_config_factory
(
platform
=
training_service_platform
,
base_path
=
base_path
)
elif
isinstance
(
kwargs
.
get
(
'trainingservice'
),
(
dict
,
list
)):
# dict means a single training service
# list means hybrid training service
kwargs
[
'trainingservice'
]
=
util
.
training_service_config_factory
(
config
=
kwargs
[
'trainingservice'
],
base_path
=
base_path
)
else
:
raise
RuntimeError
(
'Unsupported Training service configuration!'
)
super
().
__init__
(
_base_path
=
base_path
,
**
kwargs
)
for
algo_type
in
[
'tuner'
,
'assessor'
,
'advisor'
]:
if
isinstance
(
kwargs
.
get
(
algo_type
),
dict
):
setattr
(
self
,
algo_type
,
_AlgorithmConfig
(
**
kwargs
.
pop
(
algo_type
)))
if
isinstance
(
kwargs
.
get
(
'sharedstorage'
),
dict
):
setattr
(
self
,
'shared_storage'
,
SharedStorageConfig
(
_base_path
=
base_path
,
**
kwargs
.
pop
(
'sharedstorage'
)))
def
canonical
(
self
):
ret
=
super
().
canonical
()
if
isinstance
(
ret
.
training_service
,
list
):
for
i
,
ts
in
enumerate
(
ret
.
training_service
):
ret
.
training_service
[
i
]
=
ts
.
canonical
()
return
ret
def
validate
(
self
,
initialized_tuner
:
bool
=
False
)
->
None
:
super
().
validate
()
if
initialized_tuner
:
_validate_for_exp
(
self
.
canonical
())
else
:
_validate_for_nnictl
(
self
.
canonical
())
if
self
.
trial_gpu_number
and
hasattr
(
self
.
training_service
,
'use_active_gpu'
):
if
self
.
training_service
.
use_active_gpu
is
None
:
raise
ValueError
(
'Please set "use_active_gpu"'
)
def
json
(
self
)
->
Dict
[
str
,
Any
]:
obj
=
super
().
json
()
if
obj
.
get
(
'searchSpaceFile'
):
obj
[
'searchSpace'
]
=
yaml
.
safe_load
(
open
(
obj
.
pop
(
'searchSpaceFile'
)))
return
obj
## End of public API ##
@
property
def
_canonical_rules
(
self
):
return
_canonical_rules
@
property
def
_validation_rules
(
self
):
return
_validation_rules
_canonical_rules
=
{
'search_space_file'
:
util
.
canonical_path
,
'trial_code_directory'
:
util
.
canonical_path
,
'max_experiment_duration'
:
lambda
value
:
f
'
{
util
.
parse_time
(
value
)
}
s'
if
value
is
not
None
else
None
,
'experiment_working_directory'
:
util
.
canonical_path
,
'tuner_gpu_indices'
:
util
.
canonical_gpu_indices
,
'tuner'
:
lambda
config
:
None
if
config
is
None
or
config
.
name
==
'_none_'
else
config
.
canonical
(),
'assessor'
:
lambda
config
:
None
if
config
is
None
or
config
.
name
==
'_none_'
else
config
.
canonical
(),
'advisor'
:
lambda
config
:
None
if
config
is
None
or
config
.
name
==
'_none_'
else
config
.
canonical
(),
}
_validation_rules
=
{
'search_space_file'
:
lambda
value
:
(
Path
(
value
).
is_file
(),
f
'"
{
value
}
" does not exist or is not regular file'
),
'trial_code_directory'
:
lambda
value
:
(
Path
(
value
).
is_dir
(),
f
'"
{
value
}
" does not exist or is not directory'
),
'trial_concurrency'
:
lambda
value
:
value
>
0
,
'trial_gpu_number'
:
lambda
value
:
value
>=
0
,
'max_experiment_duration'
:
lambda
value
:
util
.
parse_time
(
value
)
>
0
,
'max_trial_number'
:
lambda
value
:
value
>
0
,
'max_trial_duration'
:
lambda
value
:
util
.
parse_time
(
value
)
>
0
,
'log_level'
:
lambda
value
:
value
in
[
"trace"
,
"debug"
,
"info"
,
"warning"
,
"error"
,
"fatal"
],
'tuner_gpu_indices'
:
lambda
value
:
all
(
i
>=
0
for
i
in
value
)
and
len
(
value
)
==
len
(
set
(
value
)),
'training_service'
:
lambda
value
:
(
type
(
value
)
is
not
TrainingServiceConfig
,
'cannot be abstract base class'
)
}
def
_validate_for_exp
(
config
:
ExperimentConfig
)
->
None
:
# validate experiment for nni.Experiment, where tuner is already initialized outside
if
config
.
use_annotation
:
raise
ValueError
(
'ExperimentConfig: annotation is not supported in this mode'
)
if
util
.
count
(
config
.
search_space
,
config
.
search_space_file
)
!=
1
:
raise
ValueError
(
'ExperimentConfig: search_space and search_space_file must be set one'
)
if
util
.
count
(
config
.
tuner
,
config
.
assessor
,
config
.
advisor
)
!=
0
:
raise
ValueError
(
'ExperimentConfig: tuner, assessor, and advisor must not be set in for this mode'
)
if
config
.
tuner_gpu_indices
is
not
None
:
raise
ValueError
(
'ExperimentConfig: tuner_gpu_indices is not supported in this mode'
)
def
_validate_for_nnictl
(
config
:
ExperimentConfig
)
->
None
:
# validate experiment for normal launching approach
if
config
.
use_annotation
:
if
util
.
count
(
config
.
search_space
,
config
.
search_space_file
)
!=
0
:
raise
ValueError
(
'ExperimentConfig: search_space and search_space_file must not be set with annotationn'
)
else
:
if
util
.
count
(
config
.
search_space
,
config
.
search_space_file
)
!=
1
:
raise
ValueError
(
'ExperimentConfig: search_space and search_space_file must be set one'
)
if
util
.
count
(
config
.
tuner
,
config
.
advisor
)
!=
1
:
raise
ValueError
(
'ExperimentConfig: tuner and advisor must be set one'
)
def
_validate_algo
(
algo
:
AlgorithmConfig
)
->
None
:
if
algo
.
name
is
None
:
if
algo
.
class_name
is
None
:
raise
ValueError
(
'Missing algorithm name'
)
if
algo
.
code_directory
is
not
None
and
not
Path
(
algo
.
code_directory
).
is_dir
():
raise
ValueError
(
f
'code_directory "
{
algo
.
code_directory
}
" does not exist or is not directory'
)
else
:
if
algo
.
class_name
is
not
None
or
algo
.
code_directory
is
not
None
:
raise
ValueError
(
f
'When name is set for registered algorithm, class_name and code_directory cannot be used'
)
# TODO: verify algorithm installation and class args
nni/experiment/config/convert.py
View file @
d5857823
This diff is collapsed.
Click to expand it.
nni/experiment/config/exp_config.py
0 → 100644
View file @
d5857823
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Top level experiement configuration class, ``ExperimentConfig``.
"""
__all__
=
[
'ExperimentConfig'
]
from
dataclasses
import
dataclass
import
logging
from
typing
import
Any
,
List
,
Optional
,
Union
import
yaml
from
.algorithm
import
_AlgorithmConfig
from
.base
import
ConfigBase
from
.shared_storage
import
SharedStorageConfig
from
.training_service
import
TrainingServiceConfig
from
.
import
utils
@
dataclass
(
init
=
False
)
class
ExperimentConfig
(
ConfigBase
):
"""
Class of experiment configuration. Check the reference_ for explaination of each field.
When used in Python experiment API, it can be constructed in two favors:
1. Create an empty project then set each field
.. code-block:: python
config = ExperimentConfig('local')
config.search_space = {...}
config.tuner.name = 'random'
config.training_service.use_active_gpu = True
2. Use kwargs directly
.. code-block:: python
config = ExperimentConfig(
search_space = {...},
tuner = AlgorithmConfig(name='random'),
training_service = LocalConfig(
use_active_gpu = True
)
)
Fields commented as "training service field" acts like shortcut for all training services.
Users can either specify them here or inside training service config.
In latter case hybrid training services can have different settings.
.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
"""
experiment_name
:
Optional
[
str
]
=
None
search_space_file
:
Optional
[
utils
.
PathLike
]
=
None
search_space
:
Any
=
None
trial_command
:
Optional
[
str
]
=
None
# training service field
trial_code_directory
:
utils
.
PathLike
=
'.'
# training service field
trial_concurrency
:
int
trial_gpu_number
:
Optional
[
int
]
=
None
# training service field
max_experiment_duration
:
Union
[
str
,
int
,
None
]
=
None
max_trial_number
:
Optional
[
int
]
=
None
max_trial_duration
:
Union
[
str
,
int
,
None
]
=
None
nni_manager_ip
:
Optional
[
str
]
=
None
# training service field
use_annotation
:
bool
=
False
debug
:
bool
=
False
log_level
:
Optional
[
str
]
=
None
experiment_working_directory
:
utils
.
PathLike
=
'~/nni-experiments'
tuner_gpu_indices
:
Union
[
List
[
int
],
int
,
str
,
None
]
=
None
tuner
:
Optional
[
_AlgorithmConfig
]
=
None
assessor
:
Optional
[
_AlgorithmConfig
]
=
None
advisor
:
Optional
[
_AlgorithmConfig
]
=
None
training_service
:
Union
[
TrainingServiceConfig
,
List
[
TrainingServiceConfig
]]
shared_storage
:
Optional
[
SharedStorageConfig
]
=
None
def
__init__
(
self
,
training_service_platform
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
if
training_service_platform
is
not
None
:
# the user chose to init with `config = ExperimentConfig('local')` and set fields later
# we need to create empty training service & algorithm configs to support `config.tuner.name = 'random'`
assert
utils
.
is_missing
(
self
.
training_service
)
if
isinstance
(
training_service_platform
,
list
):
self
.
training_service
=
[
utils
.
training_service_config_factory
(
ts
)
for
ts
in
training_service_platform
]
else
:
self
.
training_service
=
utils
.
training_service_config_factory
(
training_service_platform
)
for
algo_type
in
[
'tuner'
,
'assessor'
,
'advisor'
]:
# add placeholder items, so users can write `config.tuner.name = 'random'`
if
getattr
(
self
,
algo_type
)
is
None
:
setattr
(
self
,
algo_type
,
_AlgorithmConfig
(
name
=
'_none_'
))
elif
not
utils
.
is_missing
(
self
.
training_service
):
# training service is set via json or constructor
if
isinstance
(
self
.
training_service
,
list
):
self
.
training_service
=
[
utils
.
load_training_service_config
(
ts
)
for
ts
in
self
.
training_service
]
else
:
self
.
training_service
=
utils
.
load_training_service_config
(
self
.
training_service
)
def
_canonicalize
(
self
,
_parents
):
if
self
.
log_level
is
None
:
self
.
log_level
=
'debug'
if
self
.
debug
else
'info'
self
.
tuner_gpu_indices
=
utils
.
canonical_gpu_indices
(
self
.
tuner_gpu_indices
)
for
algo_type
in
[
'tuner'
,
'assessor'
,
'advisor'
]:
algo
=
getattr
(
self
,
algo_type
)
if
algo
is
not
None
and
algo
.
name
==
'_none_'
:
setattr
(
self
,
algo_type
,
None
)
super
().
_canonicalize
([
self
])
if
self
.
nni_manager_ip
is
None
:
# show a warning if user does not set nni_manager_ip. we have many issues caused by this
# the simple detection logic won't work for hybrid, but advanced users should not need it
# ideally we should check accessibility of the ip, but it need much more work
platform
=
getattr
(
self
.
training_service
,
'platform'
)
has_ip
=
isinstance
(
getattr
(
self
.
training_service
,
'nni_manager_ip'
),
str
)
# not None or MISSING
if
platform
and
platform
!=
'local'
and
not
has_ip
:
ip
=
utils
.
get_ipv4_address
()
msg
=
f
'nni_manager_ip is not set, please make sure
{
ip
}
is accessible from training machines'
logging
.
getLogger
(
'nni.experiment.config'
).
warning
(
msg
)
def
_validate_canonical
(
self
):
super
().
_validate_canonical
()
space_cnt
=
(
self
.
search_space
is
not
None
)
+
(
self
.
search_space_file
is
not
None
)
if
self
.
use_annotation
and
space_cnt
!=
0
:
raise
ValueError
(
'ExperimentConfig: search space must not be set when annotation is enabled'
)
if
not
self
.
use_annotation
and
space_cnt
<
1
:
raise
ValueError
(
'ExperimentConfig: search_space and search_space_file must be set one'
)
if
self
.
search_space_file
is
not
None
:
with
open
(
self
.
search_space_file
)
as
ss_file
:
self
.
search_space
=
yaml
.
safe_load
(
ss_file
)
# to make the error message clear, ideally it should be:
# `if concurrency < 0: raise ValueError('trial_concurrency ({concurrency}) must greater than 0')`
# but I believe there will be hardy few users make this kind of mistakes, so let's keep it simple
assert
self
.
trial_concurrency
>
0
assert
self
.
max_experiment_duration
is
None
or
utils
.
parse_time
(
self
.
max_experiment_duration
)
>
0
assert
self
.
max_trial_number
is
None
or
self
.
max_trial_number
>
0
assert
self
.
max_trial_duration
is
None
or
utils
.
parse_time
(
self
.
max_trial_duration
)
>
0
assert
self
.
log_level
in
[
'fatal'
,
'error'
,
'warning'
,
'info'
,
'debug'
,
'trace'
]
# following line is disabled because it has side effect
# enable it if users encounter problems caused by failure in creating experiment directory
# currently I have only seen one issue of this kind
#Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True)
utils
.
validate_gpu_indices
(
self
.
tuner_gpu_indices
)
tuner_cnt
=
(
self
.
tuner
is
not
None
)
+
(
self
.
advisor
is
not
None
)
if
tuner_cnt
!=
1
:
raise
ValueError
(
'ExperimentConfig: tuner and advisor must be set one'
)
nni/experiment/config/local.py
deleted
100644 → 0
View file @
cb090e8c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Union
from
.common
import
TrainingServiceConfig
from
.
import
util
__all__
=
[
'LocalConfig'
]
@
dataclass
(
init
=
False
)
class
LocalConfig
(
TrainingServiceConfig
):
platform
:
str
=
'local'
reuse_mode
:
bool
=
False
use_active_gpu
:
Optional
[
bool
]
=
None
max_trial_number_per_gpu
:
int
=
1
gpu_indices
:
Union
[
List
[
int
],
str
,
int
,
None
]
=
None
_canonical_rules
=
{
'gpu_indices'
:
util
.
canonical_gpu_indices
}
_validation_rules
=
{
'platform'
:
lambda
value
:
(
value
==
'local'
,
'cannot be modified'
),
'max_trial_number_per_gpu'
:
lambda
value
:
value
>
0
,
'gpu_indices'
:
lambda
value
:
all
(
idx
>=
0
for
idx
in
value
)
and
len
(
value
)
==
len
(
set
(
value
))
}
nni/experiment/config/remote.py
deleted
100644 → 0
View file @
cb090e8c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
List
,
Optional
,
Union
import
warnings
from
.base
import
ConfigBase
,
PathLike
from
.common
import
TrainingServiceConfig
from
.
import
util
__all__
=
[
'RemoteConfig'
,
'RemoteMachineConfig'
]
@
dataclass
(
init
=
False
)
class
RemoteMachineConfig
(
ConfigBase
):
host
:
str
port
:
int
=
22
user
:
str
password
:
Optional
[
str
]
=
None
ssh_key_file
:
PathLike
=
None
#'~/.ssh/id_rsa'
ssh_passphrase
:
Optional
[
str
]
=
None
use_active_gpu
:
bool
=
False
max_trial_number_per_gpu
:
int
=
1
gpu_indices
:
Union
[
List
[
int
],
str
,
int
,
None
]
=
None
python_path
:
Optional
[
str
]
=
None
_canonical_rules
=
{
'ssh_key_file'
:
util
.
canonical_path
,
'gpu_indices'
:
util
.
canonical_gpu_indices
}
_validation_rules
=
{
'port'
:
lambda
value
:
0
<
value
<
65536
,
'max_trial_number_per_gpu'
:
lambda
value
:
value
>
0
,
'gpu_indices'
:
lambda
value
:
all
(
idx
>=
0
for
idx
in
value
)
and
len
(
value
)
==
len
(
set
(
value
))
}
def
validate
(
self
):
super
().
validate
()
if
self
.
password
is
None
and
not
Path
(
self
.
ssh_key_file
).
is_file
():
raise
ValueError
(
f
'Password is not provided and cannot find SSH key file "
{
self
.
ssh_key_file
}
"'
)
if
self
.
password
:
warnings
.
warn
(
'Password will be exposed through web UI in plain text. We recommend to use SSH key file.'
)
@
dataclass
(
init
=
False
)
class
RemoteConfig
(
TrainingServiceConfig
):
platform
:
str
=
'remote'
reuse_mode
:
bool
=
True
machine_list
:
List
[
RemoteMachineConfig
]
def
__init__
(
self
,
**
kwargs
):
kwargs
=
util
.
case_insensitive
(
kwargs
)
kwargs
[
'machinelist'
]
=
util
.
load_config
(
RemoteMachineConfig
,
kwargs
.
get
(
'machinelist'
))
super
().
__init__
(
**
kwargs
)
_canonical_rules
=
{
'machine_list'
:
lambda
value
:
[
config
.
canonical
()
for
config
in
value
]
}
_validation_rules
=
{
'platform'
:
lambda
value
:
(
value
==
'remote'
,
'cannot be modified'
)
}
nni/experiment/config/shared_storage.py
View file @
d5857823
...
@@ -4,10 +4,23 @@
...
@@ -4,10 +4,23 @@
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
Optional
from
.common
import
SharedStorageConfig
from
.base
import
ConfigBase
from
.utils
import
PathLike
__all__
=
[
'NfsConfig'
,
'AzureBlobConfig'
]
__all__
=
[
'NfsConfig'
,
'AzureBlobConfig'
]
@
dataclass
(
init
=
False
)
class
SharedStorageConfig
(
ConfigBase
):
storage_type
:
str
local_mount_point
:
PathLike
remote_mount_point
:
str
local_mounted
:
str
storage_account_name
:
Optional
[
str
]
=
None
storage_account_key
:
Optional
[
str
]
=
None
container_name
:
Optional
[
str
]
=
None
nfs_server
:
Optional
[
str
]
=
None
exported_directory
:
Optional
[
str
]
=
None
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
NfsConfig
(
SharedStorageConfig
):
class
NfsConfig
(
SharedStorageConfig
):
storage_type
:
str
=
'NFS'
storage_type
:
str
=
'NFS'
...
@@ -19,5 +32,4 @@ class AzureBlobConfig(SharedStorageConfig):
...
@@ -19,5 +32,4 @@ class AzureBlobConfig(SharedStorageConfig):
storage_type
:
str
=
'AzureBlob'
storage_type
:
str
=
'AzureBlob'
storage_account_name
:
str
storage_account_name
:
str
storage_account_key
:
Optional
[
str
]
=
None
storage_account_key
:
Optional
[
str
]
=
None
resource_group_name
:
Optional
[
str
]
=
None
container_name
:
str
container_name
:
str
nni/experiment/config/training_service.py
0 → 100644
View file @
d5857823
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
``TrainingServiceConfig`` class.
Docstrings in this file are mainly for NNI contributors, or training service authors.
"""
__all__
=
[
'TrainingServiceConfig'
]
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
Optional
from
.base
import
ConfigBase
from
.utils
import
PathLike
,
is_missing
@
dataclass
(
init
=
False
)
class
TrainingServiceConfig
(
ConfigBase
):
"""
The base class of training service config classes.
See ``LocalConfig`` for example usage.
"""
platform
:
str
trial_command
:
str
trial_code_directory
:
PathLike
trial_gpu_number
:
Optional
[
int
]
nni_manager_ip
:
Optional
[
str
]
debug
:
bool
def
_canonicalize
(
self
,
parents
):
"""
Besides from ``ConfigBase._canonicalize()``, this overloaded version will also
copy training service specific fields from ``ExperimentConfig``.
"""
shortcuts
=
[
# fields that can set in root level config as shortcut
'trial_command'
,
'trial_code_directory'
,
'trial_gpu_number'
,
'nni_manager_ip'
,
'debug'
,
]
for
field_name
in
shortcuts
:
if
is_missing
(
getattr
(
self
,
field_name
)):
value
=
getattr
(
parents
[
0
],
field_name
)
setattr
(
self
,
field_name
,
value
)
super
().
_canonicalize
(
parents
)
def
_validate_canonical
(
self
):
super
().
_validate_canonical
()
cls
=
type
(
self
)
assert
self
.
platform
==
cls
.
platform
if
not
Path
(
self
.
trial_code_directory
).
is_dir
():
raise
ValueError
(
f
'
{
cls
.
__name__
}
: trial_code_directory "
{
self
.
trial_code_directory
}
" is not a directory'
)
assert
self
.
trial_gpu_number
is
None
or
self
.
trial_gpu_number
>=
0
nni/experiment/config/training_services/__init__.py
0 → 100644
View file @
d5857823
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.local
import
*
from
.remote
import
*
from
.openpai
import
*
from
.k8s_storage
import
*
from
.kubeflow
import
*
from
.frameworkcontroller
import
*
from
.aml
import
*
from
.dlc
import
*
nni/experiment/config/aml.py
→
nni/experiment/config/
training_services/
aml.py
View file @
d5857823
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
"""
Configuration for AML training service.
Check the reference_ for explaination of each field.
You may also want to check `AML training service doc`_.
.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
.. _AML training service doc: https://nni.readthedocs.io/en/stable/TrainingService/AMLMode.html
from
.common
import
TrainingServiceConfig
"""
__all__
=
[
'AmlConfig'
]
__all__
=
[
'AmlConfig'
]
from
dataclasses
import
dataclass
from
..training_service
import
TrainingServiceConfig
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
AmlConfig
(
TrainingServiceConfig
):
class
AmlConfig
(
TrainingServiceConfig
):
platform
:
str
=
'aml'
platform
:
str
=
'aml'
...
@@ -16,7 +29,3 @@ class AmlConfig(TrainingServiceConfig):
...
@@ -16,7 +29,3 @@ class AmlConfig(TrainingServiceConfig):
compute_target
:
str
compute_target
:
str
docker_image
:
str
=
'msranni/nni:latest'
docker_image
:
str
=
'msranni/nni:latest'
max_trial_number_per_gpu
:
int
=
1
max_trial_number_per_gpu
:
int
=
1
_validation_rules
=
{
'platform'
:
lambda
value
:
(
value
==
'aml'
,
'cannot be modified'
)
}
nni/experiment/config/dlc.py
→
nni/experiment/config/
training_services/
dlc.py
View file @
d5857823
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
.
common
import
TrainingServiceConfig
from
.
.training_service
import
TrainingServiceConfig
__all__
=
[
'DlcConfig'
]
__all__
=
[
'DlcConfig'
]
...
@@ -21,7 +21,3 @@ class DlcConfig(TrainingServiceConfig):
...
@@ -21,7 +21,3 @@ class DlcConfig(TrainingServiceConfig):
access_key_secret
:
str
access_key_secret
:
str
local_storage_mount_point
:
str
local_storage_mount_point
:
str
container_storage_mount_point
:
str
container_storage_mount_point
:
str
_validation_rules
=
{
'platform'
:
lambda
value
:
(
value
==
'dlc'
,
'cannot be modified'
)
}
nni/experiment/config/frameworkcontroller.py
→
nni/experiment/config/
training_services/
frameworkcontroller.py
View file @
d5857823
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
"""
from
typing
import
List
,
Optional
Configuration for FrameworkController training service.
from
.base
import
ConfigBase
Check the reference_ for explaination of each field.
from
.common
import
TrainingServiceConfig
from
.
import
util
__all__
=
[
You may also want to check `FrameworkController training service doc`_.
'FrameworkControllerConfig'
,
'FrameworkControllerRoleConfig'
,
'_FrameworkControllerStorageConfig'
]
.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
@
dataclass
(
init
=
False
)
.. _FrameworkController training service doc: https://nni.readthedocs.io/en/stable/TrainingService/FrameworkControllerMode.html
class
_FrameworkControllerStorageConfig
(
ConfigBase
):
storage_type
:
str
"""
server
:
Optional
[
str
]
=
None
path
:
Optional
[
str
]
=
None
__all__
=
[
'FrameworkControllerConfig'
,
'FrameworkControllerRoleConfig'
,
'FrameworkAttemptCompletionPolicy'
]
azure_account
:
Optional
[
str
]
=
None
azure_share
:
Optional
[
str
]
=
None
from
dataclasses
import
dataclass
key_vault_name
:
Optional
[
str
]
=
None
from
typing
import
List
,
Optional
,
Union
key_vault_key
:
Optional
[
str
]
=
None
from
..base
import
ConfigBase
from
..training_service
import
TrainingServiceConfig
from
.k8s_storage
import
K8sStorageConfig
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
FrameworkAttemptCompletionPolicy
(
ConfigBase
):
class
FrameworkAttemptCompletionPolicy
(
ConfigBase
):
...
@@ -38,25 +36,13 @@ class FrameworkControllerRoleConfig(ConfigBase):
...
@@ -38,25 +36,13 @@ class FrameworkControllerRoleConfig(ConfigBase):
command
:
str
command
:
str
gpu_number
:
int
gpu_number
:
int
cpu_number
:
int
cpu_number
:
int
memory_size
:
str
memory_size
:
Union
[
str
,
int
]
framework_attempt_completion_policy
:
FrameworkAttemptCompletionPolicy
framework_attempt_completion_policy
:
FrameworkAttemptCompletionPolicy
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
FrameworkControllerConfig
(
TrainingServiceConfig
):
class
FrameworkControllerConfig
(
TrainingServiceConfig
):
platform
:
str
=
'frameworkcontroller'
platform
:
str
=
'frameworkcontroller'
service_account_name
:
str
storage
:
K8sStorageConfig
storage
:
_FrameworkControllerStorageConfig
task_roles
:
List
[
FrameworkControllerRoleConfig
]
reuse_mode
:
Optional
[
bool
]
=
True
#set reuse mode as true for v2 config
service_account_name
:
Optional
[
str
]
service_account_name
:
Optional
[
str
]
task_roles
:
List
[
FrameworkControllerRoleConfig
]
def
__init__
(
self
,
**
kwargs
):
reuse_mode
:
Optional
[
bool
]
=
True
kwargs
=
util
.
case_insensitive
(
kwargs
)
kwargs
[
'storage'
]
=
util
.
load_config
(
_FrameworkControllerStorageConfig
,
kwargs
.
get
(
'storage'
))
kwargs
[
'taskroles'
]
=
util
.
load_config
(
FrameworkControllerRoleConfig
,
kwargs
.
get
(
'taskroles'
))
super
().
__init__
(
**
kwargs
)
_validation_rules
=
{
'platform'
:
lambda
value
:
(
value
==
'frameworkcontroller'
,
'cannot be modified'
)
}
nni/experiment/config/training_services/k8s_storage.py
0 → 100644
View file @
d5857823
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Storage config classes for ``KubeflowConfig`` and ``FrameworkControllerConfig``
"""
__all__
=
[
'K8sStorageConfig'
,
'K8sAzureStorageConfig'
,
'K8sNfsConfig'
]
from
dataclasses
import
dataclass
from
typing
import
Optional
from
..base
import
ConfigBase
@
dataclass
(
init
=
False
)
class
K8sStorageConfig
(
ConfigBase
):
storage_type
:
str
azure_account
:
Optional
[
str
]
=
None
azure_share
:
Optional
[
str
]
=
None
key_vault_name
:
Optional
[
str
]
=
None
key_vault_key
:
Optional
[
str
]
=
None
server
:
Optional
[
str
]
=
None
path
:
Optional
[
str
]
=
None
def
_validate_canonical
(
self
):
super
().
_validate_canonical
()
if
self
.
storage_type
==
'azureStorage'
:
assert
self
.
server
is
None
and
self
.
path
is
None
elif
self
.
storage_type
==
'nfs'
:
assert
self
.
azure_account
is
None
and
self
.
azure_share
is
None
assert
self
.
key_vault_name
is
None
and
self
.
key_vault_key
is
None
else
:
raise
ValueError
(
f
'Kubernetes storage_type ("
{
self
.
storage_type
}
") must either be "azureStorage" or "nfs"'
)
@
dataclass
(
init
=
False
)
class
K8sNfsConfig
(
K8sStorageConfig
):
storage
:
str
=
'nfs'
server
:
str
path
:
str
@
dataclass
(
init
=
False
)
class
K8sAzureStorageConfig
(
K8sStorageConfig
):
storage
:
str
=
'azureStorage'
azure_account
:
str
azure_share
:
str
key_vault_name
:
str
key_vault_key
:
str
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment