Unverified Commit 10e56560 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add shortcut to merge parameter into base setup (#2540)

parent 97f9d8a9
......@@ -78,3 +78,9 @@
.. autoclass:: nni.bohb_advisor.bohb_advisor.BOHB
:members:
```
## Utilities
```eval_rst
.. autofunction:: nni.utils.merge_parameter
```
......@@ -13,6 +13,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from nni.utils import merge_parameter
from torchvision import datasets, transforms
logger = logging.getLogger('mnist_AutoML')
......@@ -157,8 +158,7 @@ if __name__ == '__main__':
# get parameters form tuner
tuner_params = nni.get_next_parameter()
logger.debug(tuner_params)
params = vars(get_params())
params.update(tuner_params)
params = vars(merge_parameter(get_params(), tuner_params))
print(params)
main(params)
except Exception as exception:
......
......@@ -216,3 +216,43 @@ def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeTyp
else:
y = copy.deepcopy(x)
return y
def merge_parameter(base_params, override_params):
"""
Update the parameters in ``base_params`` with ``override_params``.
Can be useful to override parsed command line arguments.
Parameters
----------
base_params : namespace or dict
Base parameters. A key-value mapping.
override_params : dict or None
Parameters to override. Usually the parameters got from ``get_next_parameters()``.
When it is none, nothing will happen.
Returns
-------
namespace or dict
The updated ``base_params``. Note that ``base_params`` will be updated inplace. The return value is
only for convenience.
"""
if override_params is None:
return base_params
is_dict = isinstance(base_params, dict)
for k, v in override_params.items():
if is_dict:
if k not in base_params:
raise ValueError('Key \'%s\' not found in base parameters.' % k)
if type(base_params[k]) != type(v) and base_params[k] is not None:
raise TypeError('Expected \'%s\' in override parameters to have type \'%s\', but found \'%s\'.' %
(k, type(base_params[k]), type(v)))
base_params[k] = v
else:
if not hasattr(base_params, k):
raise ValueError('Key \'%s\' not found in base parameters.' % k)
if type(getattr(base_params, k)) != type(v) and getattr(base_params, k) is not None:
raise TypeError('Expected \'%s\' in override parameters to have type \'%s\', but found \'%s\'.' %
(k, type(getattr(base_params, k)), type(v)))
setattr(base_params, k, v)
return base_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment