Unverified Commit 10e56560 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add shortcut to merge parameter into base setup (#2540)

parent 97f9d8a9
...@@ -78,3 +78,9 @@ ...@@ -78,3 +78,9 @@
.. autoclass:: nni.bohb_advisor.bohb_advisor.BOHB .. autoclass:: nni.bohb_advisor.bohb_advisor.BOHB
:members: :members:
``` ```
## Utilities
```eval_rst
.. autofunction:: nni.utils.merge_parameter
```
...@@ -13,6 +13,7 @@ import torch ...@@ -13,6 +13,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from nni.utils import merge_parameter
from torchvision import datasets, transforms from torchvision import datasets, transforms
logger = logging.getLogger('mnist_AutoML') logger = logging.getLogger('mnist_AutoML')
...@@ -157,8 +158,7 @@ if __name__ == '__main__': ...@@ -157,8 +158,7 @@ if __name__ == '__main__':
# get parameters form tuner # get parameters form tuner
tuner_params = nni.get_next_parameter() tuner_params = nni.get_next_parameter()
logger.debug(tuner_params) logger.debug(tuner_params)
params = vars(get_params()) params = vars(merge_parameter(get_params(), tuner_params))
params.update(tuner_params)
print(params) print(params)
main(params) main(params)
except Exception as exception: except Exception as exception:
......
...@@ -216,3 +216,43 @@ def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeTyp ...@@ -216,3 +216,43 @@ def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeTyp
else: else:
y = copy.deepcopy(x) y = copy.deepcopy(x)
return y return y
def merge_parameter(base_params, override_params):
"""
Update the parameters in ``base_params`` with ``override_params``.
Can be useful to override parsed command line arguments.
Parameters
----------
base_params : namespace or dict
Base parameters. A key-value mapping.
override_params : dict or None
Parameters to override. Usually the parameters got from ``get_next_parameters()``.
When it is none, nothing will happen.
Returns
-------
namespace or dict
The updated ``base_params``. Note that ``base_params`` will be updated inplace. The return value is
only for convenience.
"""
if override_params is None:
return base_params
is_dict = isinstance(base_params, dict)
for k, v in override_params.items():
if is_dict:
if k not in base_params:
raise ValueError('Key \'%s\' not found in base parameters.' % k)
if type(base_params[k]) != type(v) and base_params[k] is not None:
raise TypeError('Expected \'%s\' in override parameters to have type \'%s\', but found \'%s\'.' %
(k, type(base_params[k]), type(v)))
base_params[k] = v
else:
if not hasattr(base_params, k):
raise ValueError('Key \'%s\' not found in base parameters.' % k)
if type(getattr(base_params, k)) != type(v) and getattr(base_params, k) is not None:
raise TypeError('Expected \'%s\' in override parameters to have type \'%s\', but found \'%s\'.' %
(k, type(getattr(base_params, k)), type(v)))
setattr(base_params, k, v)
return base_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment