import torch
import threading
import queue
import time
import gc
from loguru import logger
from collections import OrderedDict


class WeightAsyncStreamManager(object):
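    """Overlap weight transfers with compute using dedicated CUDA streams.

    ``active_weights`` is a three-slot rotation: slot 0 holds the weights the
    compute stream is currently using, slot 1 holds the previously used
    weights (offloaded back to CPU on the next prefetch), and slot 2 holds
    the weights being prefetched onto the GPU for the upcoming step.
    """
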
    def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
        self.active_weights = [None for _ in range(3)]
        self.compute_stream = torch.cuda.Stream(priority=-1)
        self.cpu_load_stream = torch.cuda.Stream(priority=0)
        self.cuda_load_stream = torch.cuda.Stream(priority=0)
        self.offload_block_num = int(offload_ratio * blocks_num)
        self.phases_num = phases_num
        self.block_nums = blocks_num
        self.offload_phases_num = blocks_num * phases_num * offload_ratio

    def prefetch_weights(self, block_idx, blocks_weights):
        # Stage the requested block's weights on the GPU using the copy stream.
        with torch.cuda.stream(self.cuda_load_stream):
            self.active_weights[2] = blocks_weights[block_idx]
            self.active_weights[2].to_cuda_async()
        # In parallel, offload the previously used weights back to the CPU.
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_block_num:
                if self.active_weights[1] is not None:
                    self.active_weights[1].to_cpu_async()

    def swap_weights(self):
        self.compute_stream.synchronize()
        self.cpu_load_stream.synchronize()
        self.cuda_load_stream.synchronize()

        self.active_weights[0], self.active_weights[1] = (
            self.active_weights[2],
            self.active_weights[0],
        )

    def prefetch_phase(self, block_idx, phase_idx, blocks):
        with torch.cuda.stream(self.cuda_load_stream):
            new_phase = blocks[block_idx].compute_phases[phase_idx]
            new_phase.to_cuda_async()
            self.active_weights[2] = (phase_idx, blocks[block_idx].compute_phases[phase_idx])
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.active_weights[1] is not None:
                    _, old_phase = self.active_weights[1]
                    old_phase.to_cpu_async()

    def swap_phases(self):
        self.compute_stream.synchronize()
        self.cpu_load_stream.synchronize()
        self.cuda_load_stream.synchronize()
        self.active_weights[0], self.active_weights[1] = self.active_weights[2], self.active_weights[0]
        self.active_weights[2] = None


class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
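    """Stream manager that additionally loads weights lazily from disk.

    Background worker threads pull tasks from a priority queue, read weights
    from disk into ``pin_memory_buffer`` (a ``MemoryBuffer``), and the
    prefetch methods then move them to the GPU on demand. ``offload_gra``
    selects the task granularity: "phase" queues individual compute phases,
    anything else queues whole blocks.
    """
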
    def __init__(self, blocks_num, offload_ratio=1, phases_num=1, num_disk_workers=1, max_memory=2, offload_gra="phase"):
        super().__init__(blocks_num, offload_ratio, phases_num)
        self.offload_gra = offload_gra
        self.worker_stop_event = threading.Event()
        self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
        self.disk_task_queue = queue.PriorityQueue()
        self.disk_workers = []
        self.release_workers = []
        self._start_disk_workers(num_disk_workers)
        self.initial_prefetch_done = False
        self.pending_tasks = {}
        self.task_lock = threading.Lock()
        self.last_used_time = {}
        self.time_lock = threading.Lock()

    def _start_disk_workers(self, num_workers):
        for i in range(num_workers):
            if self.offload_gra == "phase":
                worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
            else:
                worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
            worker.start()
            self.disk_workers.append(worker)

    def _disk_worker_loop(self):
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, phase_idx, phase = task

                phase.load_from_disk()
                self.pin_memory_buffer.push((block_idx, phase_idx), phase)

                with self.task_lock:
                    if (block_idx, phase_idx) in self.pending_tasks:
                        del self.pending_tasks[(block_idx, phase_idx)]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _disk_worker_loop_block(self):
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, block = task

                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

                with self.task_lock:
                    if block_idx in self.pending_tasks:
                        del self.pending_tasks[block_idx]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _async_prefetch_block(self, blocks, next_block_idx=None):
        if next_block_idx is None:
            next_block_idx = self.pin_memory_buffer.get_max_block_index()

        if next_block_idx < 0:
            next_block_idx = 0

        if next_block_idx == self.block_nums:
            return

        if self.offload_gra == "phase":
            for phase_idx in range(self.phases_num):
                obj_key = (next_block_idx, phase_idx)

                if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                    continue

                with self.task_lock:
                    self.pending_tasks[obj_key] = True

                phase = blocks[next_block_idx].compute_phases[phase_idx]

                priority_key = (next_block_idx, phase_idx)
                self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
        else:
            obj_key = next_block_idx
            if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                return

            with self.task_lock:
                self.pending_tasks[obj_key] = True

            block = blocks[next_block_idx]
            self.disk_task_queue.put((obj_key, (next_block_idx, block)))

    def _sync_prefetch_block(self, blocks):
        block_idx = 0
        while not self.pin_memory_buffer.is_nearly_full():
            if self.offload_gra == "phase":
                for phase_idx in range(self.phases_num):
                    phase = blocks[block_idx].compute_phases[phase_idx]
                    logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
                    phase.load_from_disk()
                    self.pin_memory_buffer.push((block_idx, phase_idx), phase)
            else:
                block = blocks[block_idx]
                logger.info(f"Synchronous loading: block={block_idx}")
                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

            block_idx += 1
            if block_idx == self.block_nums:
                break

    def prefetch_weights_from_disk(self, blocks):
        if self.initial_prefetch_done:
            return

        self._sync_prefetch_block(blocks)
        self.initial_prefetch_done = True

    def prefetch_weights(self, block_idx, blocks):
        obj_key = block_idx

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}")
            else:
                logger.info(f"Prefetch task for block={block_idx} not found.")
                logger.info(f"Synchronously prefetching block={block_idx}.")
                self._async_prefetch_block(blocks, block_idx)
                start_time = time.time()
                # Block-granularity loads are keyed by the block index alone.
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 15:
                        raise TimeoutError(f"Load timeout: block={block_idx}")

        with torch.cuda.stream(self.cuda_load_stream):
            block = self.pin_memory_buffer.get(obj_key)
            block.to_cuda_async()
            self.active_weights[2] = (obj_key, block)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_block_num:
                if self.active_weights[1] is not None:
                    old_key, old_block = self.active_weights[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_block.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def prefetch_phase(self, block_idx, phase_idx, blocks):
        obj_key = (block_idx, phase_idx)

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
            else:
                logger.info(f"Prefetch task for block={block_idx}, phase={phase_idx} not found.")
                logger.info(f"Synchronously prefetching block={block_idx}, phase={phase_idx}.")
                self._async_prefetch_block(blocks, block_idx)
                start_time = time.time()
                while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")

        with torch.cuda.stream(self.cuda_load_stream):
            phase = self.pin_memory_buffer.get(obj_key)
            phase.to_cuda_async()
            self.active_weights[2] = (obj_key, phase)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.active_weights[1] is not None:
                    old_key, old_phase = self.active_weights[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_phase.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def shutdown(self):
        self.worker_stop_event.set()

        while not self.disk_task_queue.empty():
            try:
                self.disk_task_queue.get_nowait()
            except queue.Empty:
                continue

        for _ in self.disk_workers:
            self.disk_task_queue.put((0, None))

        for worker in self.disk_workers:
            worker.join(timeout=5)

        for worker in self.release_workers:
            worker.join(timeout=5)

        logger.info("All worker threads have been closed")

    def clear(self):
        self.pin_memory_buffer.clear()
        self.shutdown()


class MemoryBuffer:
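    """Thread-safe cache of loaded weights with a byte-size budget.

    Entries are keyed by block index (block granularity) or by
    ``(block_idx, phase_idx)`` (phase granularity). ``is_nearly_full``
    reports when usage exceeds 90% of ``max_memory_bytes``, and
    ``pop_front`` evicts the oldest entry in insertion order.
    """
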
    def __init__(self, max_memory_bytes=8 * (1024**3)):
        self.cache = OrderedDict()
        self.max_mem = max_memory_bytes
        self.used_mem = 0
        self.obj_size_map = {}
        self.lock = threading.Lock()
        self.insertion_order = []
        self.insertion_index = 0

    def push(self, key, obj):
        with self.lock:
            if key in self.cache:
                return
            # Size bookkeeping: whole blocks are assumed to share one size
            # (cached under key 0); phases are sized once per phase index.
            if hasattr(obj, "compute_phases"):
                obj_idx = key
                if len(self.obj_size_map) == 0:
                    _size = 0
                    for phase in obj.compute_phases:
                        _size += phase.calculate_size()
                    self.obj_size_map[0] = _size
                size = self.obj_size_map[0]
            else:
                _, obj_idx = key
                if obj_idx not in self.obj_size_map:
                    self.obj_size_map[obj_idx] = obj.calculate_size()
                size = self.obj_size_map[obj_idx]

            self.cache[key] = (size, obj, self.insertion_index)
            self.insertion_order.append((key, self.insertion_index))
            self.insertion_index += 1
            self.used_mem += size

    def _remove_key(self, key):
        if key in self.cache:
            size, obj, idx = self.cache.pop(key)
            try:
                if hasattr(obj, "compute_phases"):
                    for phase in obj.compute_phases:
                        phase.clear()
                else:
                    obj.clear()
            except Exception as e:
                logger.info(f"Error clearing obj: {e}")
            self.used_mem -= size

            self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]

    def get(self, key, default=None):
        with self.lock:
            if key in self.cache:
                size, obj, idx = self.cache[key]
                return obj
        return default

    def exists(self, key):
        with self.lock:
            return key in self.cache

    def pop_front(self):
        with self.lock:
            if not self.insertion_order:
                return False
            front_key, _ = self.insertion_order[0]
            self._remove_key(front_key)
            return True

    def pop(self, key):
        with self.lock:
            if key in self.cache:
                self._remove_key(key)
                return True
        return False

    def is_nearly_full(self):
        with self.lock:
            return self.used_mem >= self.max_mem * 0.9

    def get_max_block_index(self):
        with self.lock:
            if not self.cache:
                return -1
            # Next block index to prefetch, wrapping at a hardcoded 40 blocks.
            if isinstance(list(self.cache.keys())[-1], tuple):
                return (list(self.cache.keys())[-1][0] + 1) % 40
            else:
                return (list(self.cache.keys())[-1] + 1) % 40

    def clear(self):
        with self.lock:
            for key in list(self.cache.keys()):
                self._remove_key(key)

            self.insertion_order = []
            self.insertion_index = 0
            self.used_mem = 0
            torch.cuda.empty_cache()
            gc.collect()
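

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). It drives
# WeightAsyncStreamManager with hypothetical stand-in blocks that expose the
# interface the manager expects (to_cuda_async / to_cpu_async); real blocks
# and their compute phases are assumed to come from the surrounding model code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _DummyBlock:
        """Hypothetical block holding a single weight tensor."""

        def __init__(self):
            self.weight = torch.zeros(1024, 1024)

        def to_cuda_async(self):
            self.weight = self.weight.to("cuda", non_blocking=True)

        def to_cpu_async(self):
            self.weight = self.weight.to("cpu", non_blocking=True)

    if torch.cuda.is_available():
        blocks_num = 4
        blocks = [_DummyBlock() for _ in range(blocks_num)]
        manager = WeightAsyncStreamManager(blocks_num, offload_ratio=1)

        # Double-buffered loop: prefetch block i on the copy streams, swap it
        # into the compute slot, then start prefetching block i + 1 while the
        # compute stream would run block i's forward pass.
        manager.prefetch_weights(0, blocks)
        for block_idx in range(blocks_num):
            manager.swap_weights()
            if block_idx + 1 < blocks_num:
                manager.prefetch_weights(block_idx + 1, blocks)
            with torch.cuda.stream(manager.compute_stream):
                pass  # the block's forward pass would execute here
        manager.compute_stream.synchronize()
        logger.info("Offload sketch finished.")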