_store_actions.py 4.87 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import enum
import threading

from cupyx.distributed import _klv_utils


class Actions(enum.IntEnum):
    Set = 1
    Get = 2
    Barrier = 3


class ActionError:
    def __init__(self, exception):
        self._exception = exception

    def klv(self):
        e = self._exception
        return _klv_utils.get_result_action_t(1, str(e).encode('ascii'))

    @staticmethod
    def from_klv(klv):
        raise RuntimeError(klv._exception.decode('utf-8'))

    def decode_result(self, data):
        ActionError.from_klv(data)


def execute_action(action, value, store):
    # receive the remaining amount of bytes that L field specifies
    try:
        if action == Actions.Set:
            action_obj = Set.from_klv(value)
        elif action == Actions.Get:
            action_obj = Get.from_klv(value)
        elif action == Actions.Barrier:
            action_obj = Barrier.from_klv(value)
        else:
            raise ValueError(f'unknown action {action}')
        return action_obj(store)
    except Exception as e:
        return ActionError(e)


class Set:
    class SetResult:
        def klv(self):
            v = bytearray(bytes(True))
            action = _klv_utils.get_result_action_t(0, v)
            return bytes(action)

        @staticmethod
        def from_klv(klv):
            return True

    def __init__(self, key, value):
        self.key = key
        self.value = value
        if not isinstance(key, str):
            raise ValueError('Invalid type for key, only str allowed')
        if type(value) not in (bytes, bytearray, int):
            raise ValueError(
                'Invalid type for value, only int or bytes allowed')
        # Check, value can only be integer or bytes

    @staticmethod
    def from_klv(value):
        value = bytes(value)
        for i, b in enumerate(value):
            if b == 0:
                k = value[:i].decode('utf-8')
                value = value[i + 1:]
                break
        else:
            raise ValueError('No separation character for key found')
        v = _klv_utils.get_value_from_bytes(value)
        return Set(k, v)

    def klv(self):
        v = bytearray(self.key.encode('ascii'))
        v.append(0)  # marker to show where the value begins
        v += _klv_utils.create_value_bytes(self.value)
        action = _klv_utils.get_action_t(Actions.Set, v)
        return bytes(action)

    def __call__(self, store):
        store.storage[self.key] = self.value
        return Set.SetResult()

    def decode_result(self, data):
        return Set.SetResult.from_klv(data)


class Get:
    class GetResult:
        def __init__(self, value):
            self.value = value

        def klv(self):
            v = _klv_utils.create_value_bytes(self.value)
            action = _klv_utils.get_result_action_t(0, v)
            return bytes(action)

        @staticmethod
        def from_klv(value):
            value = bytearray(value)
            return _klv_utils.get_value_from_bytes(value)

    def __init__(self, key):
        self.key = key
        if not isinstance(key, str):
            raise ValueError('Invalid type for key, only str allowed')

    @staticmethod
    def from_klv(value):
        k = value.decode('utf-8')
        return Get(k)

    def klv(self):
        v = bytearray(self.key.encode('ascii'))
        action = _klv_utils.get_action_t(Actions.Get, v)
        return bytes(action)

    def __call__(self, store):
        return Get.GetResult(store.storage[self.key])

    def decode_result(self, data):
        return Get.GetResult.from_klv(data)


class _BarrierImpl:
    def __init__(self, world_size):
        self._world_size = world_size
        self._cvar = threading.Condition()

    def __call__(self):
        # Superlame implementation, should be improved
        with self._cvar:
            self._world_size -= 1
            if self._world_size == 0:
                self._cvar.notifyAll()
            elif self._world_size > 0:
                self._cvar.wait()


class Barrier:
    class BarrierResult:
        def klv(self):
            v = bytearray(bytes(True))
            action = _klv_utils.get_result_action_t(0, v)
            return bytes(action)

        @staticmethod
        def from_klv(klv):
            # don't need to parse, barrier always sends True
            return True

    def klv(self):
        action = _klv_utils.get_action_t(Actions.Barrier, bytes(0))
        return bytes(action)

    @staticmethod
    def from_klv(klv):
        return Barrier()

    def __call__(self, store):
        with store._lock:
            if store._current_barrier is None:
                store._current_barrier = _BarrierImpl(store._world_size)
        store._current_barrier()
        # Once the barrier has been completed, just clean it
        with store._lock:
            store._current_barrier = None
        return Barrier.BarrierResult()

    def decode_result(self, data):
        return Barrier.BarrierResult.from_klv(data)