manager.py
import gc
import queue
import threading
import time
from collections import OrderedDict

import torch
from loguru import logger


class WeightAsyncStreamManager(object):
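    """Prefetch block weights to the GPU on side streams so copies overlap compute.

    ``active_weights`` holds three slots: [0] the weights currently used for
    compute, [1] the previously used weights waiting to be offloaded back to
    CPU, and [2] the weights whose host-to-device copy is in flight.
    ``swap_weights``/``swap_phases`` synchronize the streams and promote
    slot [2] to slot [0].

    A typical driving loop might look like the sketch below (an illustration
    only; block weights are assumed to expose ``to_cuda_async``/``to_cpu_async``
    as used in this file, and ``compute_block`` is a hypothetical helper):

        manager = WeightAsyncStreamManager(blocks_num=len(weights))
        manager.prefetch_weights(0, weights)
        for i in range(len(weights)):
            manager.swap_weights()                        # wait for block i, promote it
            if i + 1 < len(weights):
                manager.prefetch_weights(i + 1, weights)  # overlap the next H2D copy
            with torch.cuda.stream(manager.compute_stream):
                hidden = compute_block(i, hidden)         # runs while block i+1 loads
    """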
    def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
        self.init(blocks_num, phases_num, offload_ratio)
gushiqiao's avatar
gushiqiao committed
14
        self.compute_stream = torch.cuda.Stream(priority=-1)
        self.cpu_load_stream = torch.cuda.Stream(priority=0)
        self.cuda_load_stream = torch.cuda.Stream(priority=0)

    def init(self, blocks_num, phases_num, offload_ratio):
        if hasattr(self, "active_weights"):
            del self.active_weights[:]
        self.active_weights = [None for _ in range(3)]
        self.blocks_num = blocks_num
        self.phases_num = phases_num
        self.offload_ratio = offload_ratio
        self.offload_blocks_num = int(self.offload_ratio * self.blocks_num)
        self.offload_phases_num = self.blocks_num * self.phases_num * self.offload_ratio

    def prefetch_weights(self, block_idx, blocks_weights):
        with torch.cuda.stream(self.cuda_load_stream):
            self.active_weights[2] = blocks_weights[block_idx]
            self.active_weights[2].to_cuda_async()
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_blocks_num:
                if self.active_weights[1] is not None:
                    self.active_weights[1].to_cpu_async()

    def swap_weights(self):
        self.compute_stream.synchronize()
        self.cpu_load_stream.synchronize()
        self.cuda_load_stream.synchronize()

        self.active_weights[0], self.active_weights[1] = (
            self.active_weights[2],
            self.active_weights[0],
        )

    def prefetch_phase(self, block_idx, phase_idx, blocks):
        with torch.cuda.stream(self.cuda_load_stream):
            new_phase = blocks[block_idx].compute_phases[phase_idx]
            new_phase.to_cuda_async()
            self.active_weights[2] = (phase_idx, new_phase)
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.active_weights[1] is not None:
                    _, old_phase = self.active_weights[1]
                    old_phase.to_cpu_async()

    def swap_phases(self):
        self.compute_stream.synchronize()
        self.cpu_load_stream.synchronize()
        self.cuda_load_stream.synchronize()
        self.active_weights[0], self.active_weights[1] = self.active_weights[2], self.active_weights[0]
        self.active_weights[2] = None


class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
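    """Stream manager with lazy disk -> pinned CPU -> GPU weight loading.

    Background disk workers pull tasks from a priority queue and load weights
    from disk into a bounded pinned-memory ``MemoryBuffer``. ``prefetch_phase``
    and ``prefetch_weights`` then copy the cached entries to the GPU on the
    load streams and evict the entry used on the previous step. Granularity is
    either per phase (``offload_gra="phase"``) or per block. Call
    ``prefetch_weights_from_disk(blocks)`` once to warm the buffer and
    ``shutdown()`` when inference is finished.
    """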
    def __init__(self, blocks_num, offload_ratio=1, phases_num=1, num_disk_workers=1, max_memory=2, offload_gra="phase"):
        super().__init__(blocks_num, offload_ratio, phases_num)
        self.offload_gra = offload_gra
        self.worker_stop_event = threading.Event()
        self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
        self.disk_task_queue = queue.PriorityQueue()
        self.disk_workers = []
        self.release_workers = []
        self._start_disk_workers(num_disk_workers)
        self.initial_prefetch_done = False
        self.pending_tasks = {}
        self.task_lock = threading.Lock()
        self.last_used_time = {}
        self.time_lock = threading.Lock()

    def _start_disk_workers(self, num_workers):
        for i in range(num_workers):
            if self.offload_gra == "phase":
                worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
            else:
                worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
            worker.start()
            self.disk_workers.append(worker)

    def _disk_worker_loop(self):
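        """Worker loop for phase granularity: pop (block_idx, phase_idx, phase)
        tasks, load the phase from disk into the pinned-memory buffer, then
        clear its pending-task marker."""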
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, phase_idx, phase = task

                phase.load_from_disk()
                self.pin_memory_buffer.push((block_idx, phase_idx), phase)

                with self.task_lock:
                    if (block_idx, phase_idx) in self.pending_tasks:
                        del self.pending_tasks[(block_idx, phase_idx)]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _disk_worker_loop_block(self):
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, block = task

                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

                with self.task_lock:
                    if block_idx in self.pending_tasks:
                        del self.pending_tasks[block_idx]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _async_prefetch_block(self, blocks, next_block_idx=None):
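        """Queue disk-load tasks for ``next_block_idx`` (every phase in phase
        mode, or the whole block otherwise), skipping anything already cached
        in the pinned buffer or already pending."""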
        if next_block_idx is None:
            next_block_idx = self.pin_memory_buffer.get_max_block_index()

        if next_block_idx < 0:
            next_block_idx = 0

        if next_block_idx == self.blocks_num:
            return

        if self.offload_gra == "phase":
            for phase_idx in range(self.phases_num):
                obj_key = (next_block_idx, phase_idx)

                if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                    continue

                with self.task_lock:
                    self.pending_tasks[obj_key] = True

                phase = blocks[next_block_idx].compute_phases[phase_idx]

                priority_key = (next_block_idx, phase_idx)
                self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
        else:
            obj_key = next_block_idx
            if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                return

            with self.task_lock:
                self.pending_tasks[obj_key] = True

            block = blocks[next_block_idx]
            self.disk_task_queue.put((obj_key, (next_block_idx, block)))

    def _sync_prefetch_block(self, blocks):
        block_idx = 0
        while not self.pin_memory_buffer.is_nearly_full():
            if self.offload_gra == "phase":
                for phase_idx in range(self.phases_num):
                    phase = blocks[block_idx].compute_phases[phase_idx]
                    logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
                    phase.load_from_disk()
                    self.pin_memory_buffer.push((block_idx, phase_idx), phase)
            else:
                block = blocks[block_idx]
                logger.info(f"Synchronous loading: block={block_idx}")
                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

            block_idx += 1
            if block_idx == self.blocks_num:
                break

    def prefetch_weights_from_disk(self, blocks):
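        """One-time warm-up: synchronously fill the pinned-memory buffer from
        disk before the first forward pass; later calls are no-ops."""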
        if self.initial_prefetch_done:
            return

        self._sync_prefetch_block(blocks)
        self.initial_prefetch_done = True

    def prefetch_weights(self, block_idx, blocks):
        obj_key = block_idx

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}")
            else:
                logger.info("Not find prefetch block={block_idx} task.")
                logger.info("Sync prefetch block={block_idx}.")
                self._async_prefetch_block(blocks, block_idx)
                start_time = time.time()
                for phase_idx in self.phases_num:
                    while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
                        time.sleep(0.001)
                        if time.time() - start_time > 15:
                            raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")

        with torch.cuda.stream(self.cuda_load_stream):
            block = self.pin_memory_buffer.get(obj_key)
            block.to_cuda_async()
            self.active_weights[2] = (obj_key, block)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_blocks_num:
                if self.active_weights[1] is not None:
                    old_key, old_block = self.active_weights[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_block.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def prefetch_phase(self, block_idx, phase_idx, blocks):
        obj_key = (block_idx, phase_idx)

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
            else:
                logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
                logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
                self._async_prefetch_block(blocks, block_idx)
                start_time = time.time()
                while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")

        with torch.cuda.stream(self.cuda_load_stream):
            phase = self.pin_memory_buffer.get(obj_key)
            phase.to_cuda_async()
            self.active_weights[2] = (obj_key, phase)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.active_weights[1] is not None:
                    old_key, old_phase = self.active_weights[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_phase.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def shutdown(self):
        self.worker_stop_event.set()

        while not self.disk_task_queue.empty():
            try:
                self.disk_task_queue.get_nowait()
            except queue.Empty:
                continue

        # Give each sentinel a distinct priority so the PriorityQueue never has
        # to compare the unorderable None payloads against each other.
        for i, _ in enumerate(self.disk_workers):
            self.disk_task_queue.put(((-1, i), None))

        for worker in self.disk_workers:
            worker.join(timeout=5)

        for worker in self.release_workers:
            worker.join(timeout=5)

        logger.info("All worker threads have been closed")

    def clear(self):
        self.pin_memory_buffer.clear()
        self.shutdown()


class MemoryBuffer:
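    """Thread-safe FIFO cache of weight objects staged in CPU memory.

    Entries are keyed by ``block_idx`` or ``(block_idx, phase_idx)``. Per-key
    sizes are memoized in ``obj_size_map`` so ``used_mem`` can track the buffer
    footprint against ``max_mem``; ``is_nearly_full`` reports >= 90% usage.
    """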
    def __init__(self, max_memory_bytes=8 * (1024**3)):
        self.cache = OrderedDict()
        self.max_mem = max_memory_bytes
        self.used_mem = 0
        self.obj_size_map = {}
        self.lock = threading.Lock()
        self.insertion_order = []
        self.insertion_index = 0

    def push(self, key, obj):
        with self.lock:
            if key in self.cache:
                return
            if hasattr(obj, "compute_phases"):
                obj_idx = key
                if len(self.obj_size_map) == 0:
                    _size = 0
                    for phase in obj.compute_phases:
                        _size += phase.calculate_size()
                    self.obj_size_map[0] = _size
                size = self.obj_size_map[0]
            else:
                _, obj_idx = key
                if obj_idx not in self.obj_size_map:
                    self.obj_size_map[obj_idx] = obj.calculate_size()
                size = self.obj_size_map[obj_idx]

            self.cache[key] = (size, obj, self.insertion_index)
            self.insertion_order.append((key, self.insertion_index))
            self.insertion_index += 1
            self.used_mem += size

    def _remove_key(self, key):
        if key in self.cache:
            size, obj, idx = self.cache.pop(key)
            try:
                if hasattr(obj, "compute_phases"):
                    for phase in obj.compute_phases:
                        phase.clear()
                else:
                    obj.clear()
            except Exception as e:
                logger.info(f"Error clearing obj: {e}")
            self.used_mem -= size

            self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]

    def get(self, key, default=None):
        with self.lock:
            if key in self.cache:
                size, obj, idx = self.cache[key]
                return obj
        return default

    def exists(self, key):
        with self.lock:
            return key in self.cache

    def pop_front(self):
        with self.lock:
            if not self.insertion_order:
                return False
            front_key, _ = self.insertion_order[0]
            self._remove_key(front_key)
            return True

    def pop(self, key):
        with self.lock:
            if key in self.cache:
                self._remove_key(key)
                return True
        return False

    def is_nearly_full(self):
        with self.lock:
            return self.used_mem >= self.max_mem * 0.9

    def get_max_block_index(self):
        with self.lock:
            if not self.cache:
                return -1
            # NOTE: the wrap-around modulus assumes a fixed 40-block layout.
            last_key = next(reversed(self.cache))
            last_block = last_key[0] if isinstance(last_key, tuple) else last_key
            return (last_block + 1) % 40

    def clear(self):
        with self.lock:
            for key in list(self.cache.keys()):
                self._remove_key(key)

            self.insertion_order = []
            self.insertion_index = 0
            self.used_mem = 0
            torch.cuda.empty_cache()
            gc.collect()