_store.py 4.89 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import atexit
from ctypes import sizeof
import multiprocessing
import threading
import socket
import time

from cupyx.distributed import _klv_utils
from cupyx.distributed import _store_actions


# Default address of the store; used as the default arguments of both
# TCPStore.run and TCPStoreProxy.__init__.
_DEFAULT_HOST = '127.0.0.1'
_DEFAULT_PORT = 13333

# Set to True by the atexit hook `_exit`. TCPStore.stop and
# TCPStore.__del__ check it and skip their shutdown work once it is set.
_exit_mode = False


@atexit.register
def _exit():
    """Record that interpreter shutdown has started.

    Registered with ``atexit``. It only sets the module-level ``_exit_mode``
    flag, which ``TCPStore.stop`` and ``TCPStore.__del__`` consult so that
    they do nothing while the interpreter is exiting.
    """
    global _exit_mode
    _exit_mode = True


class ExceptionAwareProcess(multiprocessing.Process):
    """Process whose ``join`` re-raises an exception raised by the target.

    The child reports the outcome of ``run`` through a pipe: ``None`` when
    the target returned normally, otherwise the exception instance. The
    parent reads that report in ``join`` and re-raises the exception, if any.
    """

    def __init__(self, *args, **kwargs):
        """Create the process; arguments are forwarded to ``Process``."""
        super().__init__(*args, **kwargs)
        self._exception = None
        # The pipe is created before the process is started so that the
        # process object carries both ends with it.
        self._parent_p, self._child_p = multiprocessing.Pipe()

    def run(self):
        """Run the target and send its outcome over the pipe.

        Exceptions deriving from ``Exception`` are caught and sent instead
        of propagating, so they must be picklable to cross the pipe.
        """
        try:
            super().run()
            self._child_p.send(None)
        except Exception as e:
            self._child_p.send(e)

    def join(self, timeout=None):
        """Wait for the process and re-raise the exception it reported.

        Args:
            timeout: Optional number of seconds to wait, with the same
                meaning as in ``multiprocessing.Process.join``. ``None``
                (the default) blocks until the process terminates.

        Raises:
            Exception: Whatever exception the child sent through the pipe.
        """
        super().join(timeout)
        # poll() does not block: if nothing has been reported yet (for
        # example the timeout expired first) there is nothing to raise.
        if self._parent_p.poll():
            exception = self._parent_p.recv()
            if exception is not None:
                raise exception


class TCPStore:
    """Small TCP server that executes store actions sent by clients.

    This is only used for initialization of nccl so we don't care
    too much about performance.
    """

    def __init__(self, world_size):
        """Set up the store state; the server is not started here.

        Args:
            world_size: Stored as ``_world_size``; not used by this class
                directly (presumably read by the store actions -- confirm
                in ``_store_actions``).
        """
        self.storage = {}
        self._process = None
        self._world_size = world_size
        # Shared flag polled by the server loop; set to 0 by stop().
        self._run = multiprocessing.Value('b', 1)
        # For implementing a barrier
        self._lock = threading.Lock()
        self._current_barrier = None

    def __del__(self):
        # Skip the shutdown while the interpreter itself is exiting.
        if not _exit_mode:
            self.stop()

    def _set_process(self, process):
        """Replace the process object that ``stop`` will wait for."""
        self._process = process

    @staticmethod
    def _recv_exact(sock, size):
        """Read ``size`` bytes from ``sock``, stopping early only at EOF.

        A single ``recv`` may return fewer bytes than requested, so keep
        reading until the whole message arrived or the peer closed the
        connection. Returns the bytes read, which are shorter than ``size``
        only when the connection was closed first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _process_request(self, c_socket):
        """Serve one client connection and close it.

        Reads one ``_klv_utils.action_t`` message, executes it through
        ``_store_actions.execute_action`` and sends back the result's
        ``klv()`` bytes when the action returns one.

        Raises:
            ValueError: If the message's length field is negative or
                larger than 256, or if the peer closed the connection
                after sending only part of a message.
        """
        with c_socket:
            # Receive in KLV format
            action_bytes = self._recv_exact(
                c_socket, sizeof(_klv_utils.action_t))
            # An empty read means the peer closed without sending anything.
            if len(action_bytes) > 0:
                action_m = _klv_utils.action_t.from_buffer_copy(action_bytes)
                # A negative length would silently slice from the end.
                if not 0 <= action_m.length <= 256:
                    raise ValueError('Invalid length for message')
                value = bytearray(action_m.value)[:action_m.length]
                r = _store_actions.execute_action(action_m.action, value, self)
                if r is not None:
                    c_socket.sendall(r.klv())

    def _server_loop(self, host, port):
        """Accept connections until the ``_run`` flag is cleared.

        Args:
            host: Address to bind the listening socket to.
            port: Port to bind the listening socket to.
        """
        # This is for minimum info exchange during initialization
        # a single connection allows to implement locking mechanics easily
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind((host, port))
            s.listen()
            # The timeout makes accept() return periodically so that the
            # run flag is re-checked.
            s.settimeout(0.5)
            while self._run.value == 1:
                try:
                    c_socket, addr = s.accept()
                except socket.timeout:
                    continue

                # Each connection is handled in its own daemon thread.
                t = threading.Thread(
                    target=self._process_request,
                    args=(c_socket,), daemon=True)
                t.start()

    def run(self, host=_DEFAULT_HOST, port=_DEFAULT_PORT):
        """Start the server loop in a separate process.

        Args:
            host: Address the server binds to.
            port: Port the server binds to.
        """
        # Run the TCP store in a different process
        p = ExceptionAwareProcess(
            target=self._server_loop, args=(host, port))
        p.start()
        self._process = p

    def stop(self):
        """Ask the server loop to finish and wait for its process.

        Does nothing during interpreter exit or when no process was set.
        The process is joined only while it is still alive.
        """
        if _exit_mode:
            return  # Prevent shutdown errors
        if self._process is not None:
            with self._run.get_lock():
                self._run.value = 0
            if self._process.is_alive():
                self._process.join()


class TCPStoreProxy:
    """Client that sends store actions to a TCP store server.

    Every request opens a new connection to ``(host, port)``.
    """

    # Number of attempts made by _send_recv before giving up.
    MAX_NUM_RETRIES = 50
    # Seconds to wait after a refused connection before the next attempt.
    DELAY_FOR_RETRY = 0.5

    def __init__(self, host=_DEFAULT_HOST, port=_DEFAULT_PORT):
        """Remember the address of the store to talk to.

        Args:
            host: Address of the store server.
            port: Port of the store server.
        """
        self.host = host
        self.port = port

    @staticmethod
    def _recv_exact(sock, size):
        """Read ``size`` bytes from ``sock``, stopping early only at EOF.

        A single ``recv`` may return fewer bytes than requested, so keep
        reading until the whole message arrived or the peer closed the
        connection. Returns the bytes read, which are shorter than ``size``
        only when the connection was closed first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _send_recv(self, action):
        """Send ``action`` to the store and return its decoded result.

        Args:
            action: Object providing ``klv()`` (the bytes to send) and
                ``decode_result(value)``.

        Returns:
            ``action.decode_result(value)`` when the reply status is 0.

        Raises:
            RuntimeError: With the server's message when the reply status
                is non-zero, or 'TCPStore is not available' when every
                attempt failed.
        """
        # Retry several times in case the rank 0 has not established the
        # main store yet
        for _ in range(TCPStoreProxy.MAX_NUM_RETRIES):
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    # A refused connection is retried by the except
                    # clause below.
                    s.connect((self.host, self.port))
                    s.sendall(action.klv())
                    result_bytes = self._recv_exact(
                        s, sizeof(_klv_utils.result_action_t))
                    # An empty reply (connection closed without an answer)
                    # falls through and the request is sent again at once.
                    if len(result_bytes) > 0:
                        result = _klv_utils.result_action_t.from_buffer_copy(
                            result_bytes)
                        value = bytearray(result.value)[:result.length]
                        if result.status == 0:
                            return action.decode_result(value)
                        else:
                            raise RuntimeError(value.decode('utf-8'))
            except ConnectionRefusedError:
                time.sleep(TCPStoreProxy.DELAY_FOR_RETRY)
        raise RuntimeError('TCPStore is not available')

    def __getitem__(self, key):
        """Return the result of a ``Get`` action for ``key``."""
        return self._send_recv(_store_actions.Get(key))

    def __setitem__(self, key, value):
        """Send a ``Set`` action for ``key`` and ``value``."""
        self._send_recv(_store_actions.Set(key, value))

    def barrier(self):
        """Send a ``Barrier`` action to the store."""
        # Barrier has special semantics
        self._send_recv(_store_actions.Barrier())