"...composable_kernel_rocm.git" did not exist on "8f360206cc3168df3a1353cfd382d769d7130789"
Commit be2fbe27 authored by xuehui's avatar xuehui
Browse files

fix bugs

parent 71688b86
...@@ -2,7 +2,7 @@ authorName: default ...@@ -2,7 +2,7 @@ authorName: default
experimentName: example_mnist-keras experimentName: example_mnist-keras
trialConcurrency: 1 trialConcurrency: 1
maxExecDuration: 1h maxExecDuration: 1h
maxTrialNum: 10 maxTrialNum: 6
#choice: local, remote #choice: local, remote
trainingServicePlatform: local trainingServicePlatform: local
searchSpacePath: ~/nni/examples/trials/mnist-batch-tune-keras/search_space.json searchSpacePath: ~/nni/examples/trials/mnist-batch-tune-keras/search_space.json
......
...@@ -28,6 +28,7 @@ import random ...@@ -28,6 +28,7 @@ import random
import numpy as np import numpy as np
import nni
from nni.tuner import Tuner from nni.tuner import Tuner
TYPE = '_type' TYPE = '_type'
...@@ -73,7 +74,7 @@ class BatchTuner(Tuner): ...@@ -73,7 +74,7 @@ class BatchTuner(Tuner):
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id):
self.count +=1 self.count +=1
if self.count>len(self.values)-1: if self.count>len(self.values)-1:
return None raise nni.NoMoreTrialError('no more parameters now.')
return self.values[self.count] return self.values[self.count]
def receive_trial_result(self, parameter_id, parameters, reward): def receive_trial_result(self, parameter_id, parameters, reward):
......
...@@ -89,7 +89,6 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -89,7 +89,6 @@ class MsgDispatcher(MsgDispatcherBase):
# data: number or trial jobs # data: number or trial jobs
ids = [_create_parameter_id() for _ in range(data)] ids = [_create_parameter_id() for _ in range(data)]
params_list = self.tuner.generate_multiple_parameters(ids) params_list = self.tuner.generate_multiple_parameters(ids)
#assert len(ids) == len(params_list)
# when parameters is None. # when parameters is None.
if len(params_list) == 0: if len(params_list) == 0:
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
import logging import logging
import nni
from .recoverable import Recoverable from .recoverable import Recoverable
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -44,9 +45,11 @@ class Tuner(Recoverable): ...@@ -44,9 +45,11 @@ class Tuner(Recoverable):
""" """
result = [] result = []
for parameter_id in parameter_id_list: for parameter_id in parameter_id_list:
temp = self.generate_parameters(parameter_id) try:
if temp: res = self.generate_parameters(parameter_id)
result.append(temp) except nni.NoMoreTrialError:
return result
result.append(res)
return result return result
def receive_trial_result(self, parameter_id, parameters, reward): def receive_trial_result(self, parameter_id, parameters, reward):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment