Unverified Commit 31f11f51 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Fix search space compatibility with JSON (#4455)

parent 452e69f3
...@@ -8,7 +8,9 @@ Top level experiement configuration class, ``ExperimentConfig``. ...@@ -8,7 +8,9 @@ Top level experiement configuration class, ``ExperimentConfig``.
__all__ = ['ExperimentConfig'] __all__ = ['ExperimentConfig']
from dataclasses import dataclass from dataclasses import dataclass
import json
import logging import logging
from pathlib import Path
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union
import yaml import yaml
...@@ -113,6 +115,16 @@ class ExperimentConfig(ConfigBase): ...@@ -113,6 +115,16 @@ class ExperimentConfig(ConfigBase):
super()._canonicalize([self]) super()._canonicalize([self])
if self.search_space_file is not None:
yaml_error = None
try:
self.search_space = _load_search_space_file(self.search_space_file)
except Exception as e:
yaml_error = repr(e)
if yaml_error is not None: # raise it outside except block to make stack trace clear
msg = f'ExperimentConfig: Failed to load search space file "{self.search_space_file}": {yaml_error}'
raise ValueError(msg)
if self.nni_manager_ip is None: if self.nni_manager_ip is None:
# show a warning if user does not set nni_manager_ip. we have many issues caused by this # show a warning if user does not set nni_manager_ip. we have many issues caused by this
# the simple detection logic won't work for hybrid, but advanced users should not need it # the simple detection logic won't work for hybrid, but advanced users should not need it
...@@ -133,10 +145,6 @@ class ExperimentConfig(ConfigBase): ...@@ -133,10 +145,6 @@ class ExperimentConfig(ConfigBase):
if not self.use_annotation and space_cnt < 1: if not self.use_annotation and space_cnt < 1:
raise ValueError('ExperimentConfig: search_space and search_space_file must be set one') raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')
if self.search_space_file is not None:
with open(self.search_space_file) as ss_file:
self.search_space = yaml.safe_load(ss_file)
# to make the error message clear, ideally it should be: # to make the error message clear, ideally it should be:
# `if concurrency < 0: raise ValueError('trial_concurrency ({concurrency}) must greater than 0')` # `if concurrency < 0: raise ValueError('trial_concurrency ({concurrency}) must greater than 0')`
# but I believe hardly any users will make this kind of mistake, so let's keep it simple # but I believe hardly any users will make this kind of mistake, so let's keep it simple
...@@ -156,3 +164,13 @@ class ExperimentConfig(ConfigBase): ...@@ -156,3 +164,13 @@ class ExperimentConfig(ConfigBase):
tuner_cnt = (self.tuner is not None) + (self.advisor is not None) tuner_cnt = (self.tuner is not None) + (self.advisor is not None)
if tuner_cnt != 1: if tuner_cnt != 1:
raise ValueError('ExperimentConfig: tuner and advisor must be set one') raise ValueError('ExperimentConfig: tuner and advisor must be set one')
def _load_search_space_file(search_space_path):
# FIXME
# we need this because PyYAML 6.0 does not support YAML 1.2,
# which means it is not fully compatible with JSON
content = Path(search_space_path).read_text(encoding='utf8')
try:
return json.loads(content)
except Exception:
return yaml.safe_load(content)
pool_type:
_type: choice
_value:
- max
- min
- avg
学习率: # test unicode
_type: loguniform
_value: [ 0.0000001, 0.1 ]
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ],
},
"学习率": {
"_type": "loguniform",
"_value": [ 0.0000001, 0.1 ],
},
}
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ]
},
"学习率": {
"_type": "loguniform",
"_value": [ 1e-7, 0.1 ]
}
}
{
"pool_type": {
"_type": "choice",
"_value": [ "max", "min", "avg" ],
},
"学习率": {
"_type": "loguniform",
"_value": [ 1e-7, 0.1 ],
},
}
pool_type:
_type: choice
_value:
- max
- min
- avg
学习率: # test unicode
_type: loguniform
_value: [ 1e-7, 0.1 ] # test scientific notation
import json
from pathlib import Path
import yaml
from nni.experiment.config import ExperimentConfig, AlgorithmConfig, LocalConfig
# Shared experiment config used as a template by the tests in this module.
# `search_space_file` starts empty; each test case points it at the asset under test.
config = ExperimentConfig(
    search_space_file='',
    trial_command='echo hello',
    trial_concurrency=1,
    tuner=AlgorithmConfig(name='randomm'),
    training_service=LocalConfig(),
)
# The search space every asset file is expected to load to, regardless of its format.
_pool_type_spec = {'_type': 'choice', '_value': ['max', 'min', 'avg']}
_learning_rate_spec = {'_type': 'loguniform', '_value': [1e-7, 0.1]}

space_correct = {
    'pool_type': _pool_type_spec,
    '学习率': _learning_rate_spec,  # non-ASCII key, exercises unicode handling
}
# (asset file name, human readable description) pairs exercised by the test.
#
# FIXME
# Some combinations fail to load and are therefore disabled below:
#   - PyYAML 6.0 (YAML 1.1) does not support tab and scientific notation
#   - JSON does not support comment and extra comma
formats = [
    (
        'ss_tab.json',
        'JSON (tabs + scientific notation)',
    ),
    (
        'ss_comma.json',
        'JSON with extra comma',
    ),
    # ('ss_tab_comma.json', 'JSON (tabs + scientific notation) with extra comma'),
    (
        'ss.yaml',
        'YAML',
    ),
    # ('ss_yaml12.yaml', 'YAML 1.2 with scientific notation'),
]
def test_search_space():
    """Check that each asset listed in ``formats`` loads to ``space_correct``."""
    for file_name, label in formats:
        try:
            asset_path = Path(__file__).parent / 'assets' / file_name
            config.search_space_file = asset_path
            loaded_space = config.json()['searchSpace']
            assert loaded_space == space_correct
        except Exception as e:
            # name the failing format before propagating the original error
            print('Failed to load search space format: ' + label)
            raise e
# Allow running this module directly as a script, without a test runner.
if __name__ == '__main__':
    test_search_space()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment