import torch
import threading
import queue
import time
import gc
from loguru import logger
from collections import OrderedDict


class WeightAsyncStreamManager(object):
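    # Double-buffers per-block weights across three CUDA streams: a
    # high-priority compute stream, a host-to-device stream that prefetches the
    # next block, and a device-to-host stream that offloads the previous block.
    # `active_weights` holds three slots:
    #   [0] weights of the block currently being computed,
    #   [1] weights of the previous block (offloaded on the next prefetch),
    #   [2] weights prefetched for the upcoming block.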
    def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
        self.active_weights = [None for _ in range(3)]
        self.compute_stream = torch.cuda.Stream(priority=-1)
        self.cpu_load_stream = torch.cuda.Stream(priority=0)
        self.cuda_load_stream = torch.cuda.Stream(priority=0)
        self.offload_block_num = int(offload_ratio * blocks_num)
        self.phases_num = phases_num
        self.offload_phases_num = blocks_num * phases_num * offload_ratio

    def prefetch_weights(self, block_idx, blocks_weights):
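        # Copy the weights for `block_idx` to GPU on the load stream; in
        # parallel, move the previously used block (slot 1) back to CPU while
        # the block index is still within the offload range.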
        with torch.cuda.stream(self.cuda_load_stream):
            self.active_weights[2] = blocks_weights[block_idx]
            self.active_weights[2].to_cuda_async()
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_block_num:
                if self.active_weights[1] is not None:
                    self.active_weights[1].to_cpu_async()

    def swap_weights(self):
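        # Wait for the compute and both copy streams, then rotate the
        # prefetched weights (slot 2) into the compute slot (0) and the old
        # compute weights into the offload slot (1).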
        self.compute_stream.synchronize()
        self.cpu_load_stream.synchronize()
        self.cuda_load_stream.synchronize()

        self.active_weights[0], self.active_weights[1] = (
            self.active_weights[2],
            self.active_weights[0],
        )

    def prefetch_phase(self, block_idx, phase_idx, blocks):
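        # Same as prefetch_weights, but at the granularity of a single compute
        # phase within a block.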
        with torch.cuda.stream(self.cuda_load_stream):
            new_phase = blocks[block_idx].compute_phases[phase_idx]
            new_phase.to_cuda_async()
            self.active_weights[2] = (phase_idx, new_phase)
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.active_weights[1] is not None:
                    _, old_phase = self.active_weights[1]
                    old_phase.to_cpu_async()

    def swap_phases(self):
        self.compute_stream.synchronize()
        self.cpu_load_stream.synchronize()
        self.cuda_load_stream.synchronize()
        self.active_weights[0], self.active_weights[1] = self.active_weights[2], self.active_weights[0]
        self.active_weights[2] = None


class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
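    # Extends the base manager for weights that start on disk: background
    # worker threads load blocks (or phases) into a size-bounded pinned-memory
    # buffer, from which they are staged onto the GPU; evicted entries are
    # cleared and will be re-read from disk the next time they are needed.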
    def __init__(self, blocks_num, offload_ratio=1, phases_num=1, num_disk_workers=1, max_memory=2, offload_gra="phase"):
        super().__init__(blocks_num, offload_ratio, phases_num)
        self.offload_gra = offload_gra
        self.worker_stop_event = threading.Event()
        self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
        self.disk_task_queue = queue.PriorityQueue()
        self.disk_workers = []
        self.release_workers = []
        self._start_disk_workers(num_disk_workers)
        self.initial_prefetch_done = False
        self.pending_tasks = {}
        self.task_lock = threading.Lock()
        self.last_used_time = {}
        self.time_lock = threading.Lock()

    def _start_disk_workers(self, num_workers):
        for i in range(num_workers):
            if self.offload_gra == "phase":
                worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
            else:
                worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
            worker.start()
            self.disk_workers.append(worker)

    def _disk_worker_loop(self):
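        # Phase-granularity worker: pop (block_idx, phase_idx, phase) tasks
        # from the priority queue, load each phase from disk into the
        # pinned-memory buffer, then clear its pending-task entry.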
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, phase_idx, phase = task

                phase.load_from_disk()
                self.pin_memory_buffer.push((block_idx, phase_idx), phase)

                with self.task_lock:
                    if (block_idx, phase_idx) in self.pending_tasks:
                        del self.pending_tasks[(block_idx, phase_idx)]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _disk_worker_loop_block(self):
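        # Block-granularity worker: load every phase of the block from disk,
        # then publish the whole block to the pinned-memory buffer.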
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, block = task

                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

                with self.task_lock:
                    if block_idx in self.pending_tasks:
                        del self.pending_tasks[block_idx]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _async_prefetch_block(self, weights):
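        # Queue disk-load tasks for the block that follows the most recently
        # buffered one, skipping anything already buffered or already pending.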
        next_block_idx = self.pin_memory_buffer.get_max_block_index()

        if next_block_idx < 0:
            next_block_idx = 0

        if self.offload_gra == "phase":
            for phase_idx in range(self.phases_num):
                obj_key = (next_block_idx, phase_idx)

                if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                    continue

                with self.task_lock:
                    self.pending_tasks[obj_key] = True

                phase = weights.blocks[next_block_idx].compute_phases[phase_idx]

                priority_key = (next_block_idx, phase_idx)
                self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
        else:
            obj_key = next_block_idx
            if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                return

            with self.task_lock:
                self.pending_tasks[obj_key] = True

            block = weights.blocks[next_block_idx]
            self.disk_task_queue.put((obj_key, (next_block_idx, block)))

    def _sync_prefetch_block(self, weights):
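        # Warm-up path: load blocks from disk synchronously, in order, until
        # the pinned-memory buffer is nearly full.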
        block_idx = 0
        while not self.pin_memory_buffer.is_nearly_full():
            if self.offload_gra == "phase":
                for phase_idx in range(self.phases_num):
                    phase = weights.blocks[block_idx].compute_phases[phase_idx]
                    logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
                    phase.load_from_disk()
                    self.pin_memory_buffer.push((block_idx, phase_idx), phase)
            else:
                block = weights.blocks[block_idx]
                logger.info(f"Synchronous loading: block={block_idx}")
                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

            block_idx += 1

    def prefetch_weights_from_disk(self, weights):
        if self.initial_prefetch_done:
            return

        self._sync_prefetch_block(weights)
        self.initial_prefetch_done = True

    def prefetch_weights(self, block_idx, blocks):
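        # If the block is not in the pinned-memory buffer yet, wait (up to 5 s)
        # for its pending disk load; then stage it on the GPU load stream and
        # offload/evict the previously used block.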
        obj_key = block_idx

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}")
            else:
                logger.info("Not find prefetch block={block_idx} task. This is a bug.")

        with torch.cuda.stream(self.cuda_load_stream):
            block = self.pin_memory_buffer.get(obj_key)
            block.to_cuda_async()
            self.active_weights[2] = (obj_key, block)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_block_num:
                if self.active_weights[1] is not None:
                    old_key, old_block = self.active_weights[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_block.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def prefetch_phase(self, block_idx, phase_idx, blocks):
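        # Phase-granularity variant of prefetch_weights: wait for the pending
        # disk load if needed, stage the phase on GPU, and evict the previously
        # used phase from the buffer.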
        obj_key = (block_idx, phase_idx)

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
            else:
                logger.info("Not find prefetch block={block_idx}, phase={phase_idx} task. This is a bug.")

        with torch.cuda.stream(self.cuda_load_stream):
            phase = self.pin_memory_buffer.get(obj_key)
            phase.to_cuda_async()
            self.active_weights[2] = (obj_key, phase)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.active_weights[1] is not None:
                    old_key, old_phase = self.active_weights[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_phase.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def shutdown(self):
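        # Stop the workers: signal the stop event, drain queued tasks, push one
        # sentinel per worker, and join everything with a timeout.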
        self.worker_stop_event.set()

        while not self.disk_task_queue.empty():
            try:
                self.disk_task_queue.get_nowait()
            except queue.Empty:
                continue

        for _ in self.disk_workers:
            self.disk_task_queue.put((0, None))

        for worker in self.disk_workers:
            worker.join(timeout=5)

        for worker in self.release_workers:
            worker.join(timeout=5)

        logger.info("All worker threads have been closed")

    def clear(self):
        self.pin_memory_buffer.clear()
        self.shutdown()


class MemoryBuffer:
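    # Thread-safe, size-bounded cache of host-resident weight objects keyed by
    # block index or (block_idx, phase_idx) tuple; tracks insertion order so
    # the oldest entry can be evicted via pop_front().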
    def __init__(self, max_memory_bytes=8 * (1024**3)):
        self.cache = OrderedDict()
        self.max_mem = max_memory_bytes
        self.used_mem = 0
        self.obj_size_map = {}
        self.lock = threading.Lock()
        self.insertion_order = []
        self.insertion_index = 0

    def push(self, key, obj):
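        # Measure an entry's size once per phase index (or once for the first
        # block when the object exposes compute_phases) and reuse it, which
        # assumes all blocks share the same layout.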
        with self.lock:
            if key in self.cache:
                return
            if hasattr(obj, "compute_phases"):
                obj_idx = key
                if len(self.obj_size_map) == 0:
                    _size = 0
                    for phase in obj.compute_phases:
                        _size += phase.calculate_size()
                    self.obj_size_map[0] = _size
                size = self.obj_size_map[0]
            else:
                _, obj_idx = key
                if obj_idx not in self.obj_size_map:
                    self.obj_size_map[obj_idx] = obj.calculate_size()
                size = self.obj_size_map[obj_idx]

            self.cache[key] = (size, obj, self.insertion_index)
            self.insertion_order.append((key, self.insertion_index))
            self.insertion_index += 1
            self.used_mem += size

    def _remove_key(self, key):
        if key in self.cache:
            size, obj, idx = self.cache.pop(key)
            try:
                if hasattr(obj, "compute_phases"):
                    for phase in obj.compute_phases:
                        phase.clear()
                else:
                    obj.clear()
            except Exception as e:
                logger.info(f"Error clearing obj: {e}")
            self.used_mem -= size

            self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]

    def get(self, key, default=None):
        with self.lock:
            if key in self.cache:
                size, obj, idx = self.cache[key]
                return obj
        return default

    def exists(self, key):
        with self.lock:
            return key in self.cache

    def pop_front(self):
        with self.lock:
            if not self.insertion_order:
                return False
            front_key, _ = self.insertion_order[0]
            self._remove_key(front_key)
            return True

    def pop(self, key):
        with self.lock:
            if key in self.cache:
                self._remove_key(key)
                return True
        return False

    def is_nearly_full(self):
        with self.lock:
            return self.used_mem >= self.max_mem * 0.9

    def get_max_block_index(self):
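        # Returns the block index to prefetch next: one past the most recently
        # inserted key, wrapped at the hardcoded block count of 40.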
        with self.lock:
            if not self.cache:
                return -1
            last_key = next(reversed(self.cache))
            if isinstance(last_key, tuple):
                return (last_key[0] + 1) % 40
            else:
                return (last_key + 1) % 40

    def clear(self):
        with self.lock:
            for key in list(self.cache.keys()):
                self._remove_key(key)

            self.insertion_order = []
            self.insertion_index = 0
            self.used_mem = 0
            torch.cuda.empty_cache()
            gc.collect()
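

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a `blocks` sequence of
# weight objects exposing to_cuda_async()/to_cpu_async(), and a hypothetical
# run_block() that consumes the weights staged in slot 0:
#
#   manager = WeightAsyncStreamManager(blocks_num=len(blocks), offload_ratio=1)
#   manager.prefetch_weights(0, blocks)
#   manager.swap_weights()
#   for idx in range(len(blocks)):
#       if idx + 1 < len(blocks):
#           manager.prefetch_weights(idx + 1, blocks)
#       with torch.cuda.stream(manager.compute_stream):
#           run_block(manager.active_weights[0])
#       manager.swap_weights()
# ---------------------------------------------------------------------------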