"include/ck/utility/utility.hpp" did not exist on "05e046654c9a226444091806a418a77fe0e4a4c2"
simple_tuner.py 2.04 KB
Newer Older
Yan Ni's avatar
Yan Ni committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
SimpleTuner for Weight Sharing
"""

import logging

from threading import Event, Lock
from nni.tuner import Tuner

# Logger used by SimpleTuner, registered under the name 'WeightSharingTuner'.
_logger = logging.getLogger('WeightSharingTuner')


class SimpleTuner(Tuner):
    """
    Simple tuner used to test weight sharing between trials.

    The first parameter request claims the "father" role and is answered
    immediately. Every later request blocks until the father has reported
    its result, then receives parameters pointing at the father's id and the
    ``path`` the father reported. Results of the other trials are validated
    by comparing their reported ``checksum`` with the father's.

    Because child requests block on ``sig_event`` until the father reports,
    calls are expected to arrive concurrently; all reads/writes of
    ``trial_meta`` and ``f_id`` are serialized by ``thread_lock``.
    """

    def __init__(self):
        super(SimpleTuner, self).__init__()
        # parameter_id -> metadata dict recorded for that trial
        self.trial_meta = {}
        self.f_id = None  # parameter_id of the father trial, once claimed
        # set by receive_trial_result once the father has reported
        self.sig_event = Event()
        # guards trial_meta and f_id
        self.thread_lock = Lock()

    def generate_parameters(self, parameter_id):
        """
        Return the parameters for the trial identified by ``parameter_id``.

        The first call claims the father role and returns
        ``{'prev_id': 0, 'id': parameter_id}``. Every later call blocks until
        the father's result has been received, records and returns
        ``{'id': parameter_id, 'prev_id': <father id>,
        'prev_path': <father path>}``.
        """
        # Test-and-claim the father role atomically so two concurrent first
        # requests cannot both become the father.
        with self.thread_lock:
            if self.f_id is None:
                self.f_id = parameter_id
                self.trial_meta[parameter_id] = {
                    'prev_id': 0,
                    'id': parameter_id,
                    'checksum': None,
                    'path': '',
                }
                _logger.info('generate parameter for father trial %s',
                             parameter_id)
                return {
                    'prev_id': 0,
                    'id': parameter_id,
                }

        # Wait WITHOUT holding the lock: receive_trial_result must acquire it
        # to record the father's result and set the event.
        self.sig_event.wait()
        with self.thread_lock:
            meta = {
                'id': parameter_id,
                'prev_id': self.f_id,
                'prev_path': self.trial_meta[self.f_id]['path'],
            }
            self.trial_meta[parameter_id] = meta
        return meta

    def receive_trial_result(self, parameter_id, parameters, reward, **kwargs):
        """
        Record the father's result, or validate another trial's result.

        Parameters
        ----------
        parameter_id
            Id of the trial reporting the result.
        parameters
            Parameters the trial ran with; unused here.
        reward
            Mapping with a ``'checksum'`` entry; the father's reward must
            also carry ``'path'``.
        **kwargs
            Extra keyword arguments supplied by the framework; ignored.

        Raises
        ------
        ValueError
            If a non-father trial reports a checksum different from the
            father's.
        """
        # `with` guarantees the lock is released even when a KeyError or the
        # ValueError below propagates; otherwise every later call deadlocks.
        with self.thread_lock:
            if parameter_id == self.f_id:
                father_meta = self.trial_meta[parameter_id]
                father_meta['checksum'] = reward['checksum']
                father_meta['path'] = reward['path']
                # Unblock child requests waiting in generate_parameters.
                self.sig_event.set()
            else:
                expected = self.trial_meta[self.f_id]['checksum']
                if reward['checksum'] != expected:
                    raise ValueError("Inconsistency in weight sharing: {} != {}".format(
                        reward['checksum'], expected))

    def update_search_space(self, search_space):
        """Ignore search-space updates; this tuner does not use one."""
        pass