random_nas_tuner.py 2.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import logging

import numpy as np

from nni.tuner import Tuner

def _sample_optional_inputs(layer, random_state):
    '''Pick the inputs for one layer from its ``optional_inputs`` candidates.

    The number of picks is ``layer['optional_input_size']`` when that is an
    int; otherwise it is treated as a ``[low, high]`` pair and the count is
    drawn from that inclusive range. Each pick is an independent draw, so the
    same candidate may appear more than once.

    Returns an empty list when there are no candidates.
    '''
    candidates = layer['optional_inputs']
    if not candidates:
        return []
    size_spec = layer['optional_input_size']
    if isinstance(size_spec, int):
        choice_num = size_spec
    else:
        # randint's upper bound is exclusive, hence the +1 to include `high`.
        choice_num = random_state.randint(size_spec[0], size_spec[1] + 1)
    return [
        candidates[random_state.randint(len(candidates))]
        for _ in range(choice_num)
    ]


def random_archi_generator(nas_ss, random_state):
    '''Randomly choose one architecture from a NAS search space.

    Parameters
    ----------
    nas_ss : dict
        Mapping ``block_name -> {layer_name -> layer}``. Each ``layer`` is a
        dict that may contain the keys ``layer_choice`` (sequence of candidate
        layers), ``optional_inputs`` (sequence of candidate inputs) and
        ``optional_input_size`` (int, or ``[low, high]`` inclusive range).
    random_state : numpy.random.RandomState
        Source of randomness; only its ``randint`` method is used.

    Returns
    -------
    dict
        Same block/layer nesting as ``nas_ss``. Each layer maps to a dict with
        ``chosen_layer`` (if ``layer_choice`` was given) and ``chosen_inputs``
        (if ``optional_inputs`` was given).

    Raises
    ------
    ValueError
        If a layer contains an unknown key, or ``layer_choice`` is empty.
    KeyError
        If ``optional_inputs`` is non-empty but ``optional_input_size`` is
        missing.
    '''
    logger = logging.getLogger(__name__)
    logger.debug('nas search space: %s', nas_ss)

    chosen_archi = {}
    for block_name, block in nas_ss.items():
        chosen_block = {}
        for layer_name, layer in block.items():
            chosen_layer = {}
            # Keys are handled in dict order so that the sequence of RNG draws
            # (and therefore the result for a seeded random_state) is stable.
            for key, value in layer.items():
                if key == 'layer_choice':
                    num_candidates = len(value)
                    if num_candidates == 0:
                        # randint(0) would fail with an opaque numpy message.
                        raise ValueError('Empty layer_choice in layer %s of block %s'
                                         % (layer_name, block_name))
                    chosen_layer['chosen_layer'] = value[random_state.randint(num_candidates)]
                elif key == 'optional_inputs':
                    logger.debug('optional_inputs: %s', value)
                    chosen_layer['chosen_inputs'] = _sample_optional_inputs(layer, random_state)
                elif key != 'optional_input_size':
                    # optional_input_size is consumed together with optional_inputs.
                    raise ValueError('Unknown field %s in layer %s of block %s' % (key, layer_name, block_name))
            chosen_block[layer_name] = chosen_layer
        chosen_archi[block_name] = chosen_block
    return chosen_archi

class RandomNASTuner(Tuner):
    '''Tuner that proposes architectures by random sampling.

    Every call to ``generate_parameters`` draws a fresh architecture from the
    most recently supplied search space via ``random_archi_generator``. Trial
    results are ignored, so earlier trials never influence later proposals.
    '''

    def __init__(self):
        super().__init__()
        # Both stay None until update_search_space() is called.
        self.searchspace_json = None
        self.random_state = None

    def update_search_space(self, search_space):
        '''Store the search space and start a new random stream.

        Parameters
        ----------
        search_space : dict
            NAS search space in the form accepted by ``random_archi_generator``.
        '''
        self.searchspace_json = search_space
        # Unseeded: a new RandomState is created on every search-space update.
        self.random_state = np.random.RandomState()

    def generate_parameters(self, parameter_id):
        '''Return one randomly chosen architecture.

        Parameters
        ----------
        parameter_id : any
            Identifier of the requested parameter set; not used here.

        Returns
        -------
        dict
            The architecture produced by ``random_archi_generator``.

        Raises
        ------
        RuntimeError
            If called before ``update_search_space``.
        '''
        if self.searchspace_json is None or self.random_state is None:
            raise RuntimeError('update_search_space() must be called before generate_parameters()')
        return random_archi_generator(self.searchspace_json, self.random_state)

    def receive_trial_result(self, parameter_id, parameters, value):
        '''Accept a trial result and discard it.

        Random search keeps no state between trials, so there is nothing to do.
        '''