Commit ea89e232 authored by Yuge Zhang's avatar Yuge Zhang Committed by Guoxin
Browse files

refactor get_current_parameter to avoid using trial._params (#1458)

parent 427f1256
...@@ -32,7 +32,7 @@ def classic_mode( ...@@ -32,7 +32,7 @@ def classic_mode(
'''Execute the chosen function and inputs directly. '''Execute the chosen function and inputs directly.
In this mode, the trial code is only running the chosen subgraph (i.e., the chosen ops and inputs), In this mode, the trial code is only running the chosen subgraph (i.e., the chosen ops and inputs),
without touching the full model graph.''' without touching the full model graph.'''
if trial._params is None: if trial.get_current_parameter() is None:
trial.get_next_parameter() trial.get_next_parameter()
mutable_block = trial.get_current_parameter(mutable_id) mutable_block = trial.get_current_parameter(mutable_id)
chosen_layer = mutable_block[mutable_layer_id]["chosen_layer"] chosen_layer = mutable_block[mutable_layer_id]["chosen_layer"]
...@@ -118,7 +118,7 @@ def oneshot_mode( ...@@ -118,7 +118,7 @@ def oneshot_mode(
The difference is that oneshot mode does not receive subgraph. The difference is that oneshot mode does not receive subgraph.
Instead, it uses dropout to randomly dropout inputs and ops.''' Instead, it uses dropout to randomly dropout inputs and ops.'''
# NNI requires to get_next_parameter before report a result. But the parameter will not be used in this mode # NNI requires to get_next_parameter before report a result. But the parameter will not be used in this mode
if trial._params is None: if trial.get_current_parameter() is None:
trial.get_next_parameter() trial.get_next_parameter()
optional_inputs = list(optional_inputs.values()) optional_inputs = list(optional_inputs.values())
inputs_num = len(optional_inputs) inputs_num = len(optional_inputs)
......
...@@ -189,6 +189,6 @@ else: ...@@ -189,6 +189,6 @@ else:
raise RuntimeError('Unrecognized mode: %s' % mode) raise RuntimeError('Unrecognized mode: %s' % mode)
def _get_param(key): def _get_param(key):
if trial._params is None: if trial.get_current_parameter() is None:
trial.get_next_parameter() trial.get_next_parameter()
return trial.get_current_parameter(key) return trial.get_current_parameter(key)
...@@ -50,10 +50,12 @@ def get_next_parameter(): ...@@ -50,10 +50,12 @@ def get_next_parameter():
return None return None
return _params['parameters'] return _params['parameters']
def get_current_parameter(tag): def get_current_parameter(tag=None):
global _params global _params
if _params is None: if _params is None:
return None return None
if tag is None:
return _params['parameters']
return _params['parameters'][tag] return _params['parameters'][tag]
def get_experiment_id(): def get_experiment_id():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment