"vscode:/vscode.git/clone" did not exist on "a19d47642e49814c3fba634cd50a4898ce1a5a57"
manager.py 14.2 KB
Newer Older
gushiqiao's avatar
gushiqiao committed
1
import torch
import threading
import queue
import time
import gc
from loguru import logger
from collections import OrderedDict


class WeightAsyncStreamManager(object):
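    """Overlaps weight movement with compute using dedicated CUDA streams.

    ``active_weights`` holds three slots: the weights currently being used for
    compute, the previously used weights (candidates for offload back to CPU),
    and the weights being prefetched for the next step. ``swap_weights`` and
    ``swap_phases`` rotate these slots once all streams have synchronized.
    """
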
    def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
        self.active_weights = [None for _ in range(3)]
        self.compute_stream = torch.cuda.Stream(priority=-1)
        self.cpu_load_stream = torch.cuda.Stream(priority=0)
        self.cuda_load_stream = torch.cuda.Stream(priority=0)
        self.offload_block_num = int(offload_ratio * blocks_num)
        self.phases_num = phases_num
        self.offload_phases_num = blocks_num * phases_num * offload_ratio

    def prefetch_weights(self, block_idx, blocks_weights):
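        """Asynchronously move block ``block_idx`` to GPU and offload the previously used block back to CPU."""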
        with torch.cuda.stream(self.cuda_load_stream):
            self.active_weights[2] = blocks_weights[block_idx]
            self.active_weights[2].to_cuda_async()
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_block_num:
                if self.active_weights[1] is not None:
                    self.active_weights[1].to_cpu_async()

    def swap_weights(self):
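        """Synchronize all streams, then rotate the freshly prefetched weights into the active slot."""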
        self.compute_stream.synchronize()
        self.cpu_load_stream.synchronize()
        self.cuda_load_stream.synchronize()

        self.active_weights[0], self.active_weights[1] = (
            self.active_weights[2],
            self.active_weights[0],
        )

    def prefetch_phase(self, block_idx, phase_idx, blocks):
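        """Asynchronously move one compute phase of block ``block_idx`` to GPU and offload the previously used phase."""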
        with torch.cuda.stream(self.cuda_load_stream):
            new_phase = blocks[block_idx].compute_phases[phase_idx]
            new_phase.to_cuda_async()
            self.active_weights[2] = (phase_idx, new_phase)
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.active_weights[1] is not None:
                    _, old_phase = self.active_weights[1]
                    old_phase.to_cpu_async()

    def swap_phases(self):
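        """Synchronize all streams, rotate the prefetched phase into the active slot, and clear the prefetch slot."""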
        self.compute_stream.synchronize()
        self.cpu_load_stream.synchronize()
        self.cuda_load_stream.synchronize()
        self.active_weights[0], self.active_weights[1] = self.active_weights[2], self.active_weights[0]
        self.active_weights[2] = None


class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
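    """Stream manager that additionally loads weights lazily from disk.

    Disk worker threads read weights into a bounded pinned-memory cache
    (``MemoryBuffer``); prefetching then stages weights onto the GPU from that
    cache instead of from always-resident CPU tensors. ``offload_gra`` selects
    the disk I/O granularity: ``"phase"`` loads individual compute phases,
    anything else loads whole blocks.
    """
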
    def __init__(self, blocks_num, offload_ratio=1, phases_num=1, num_disk_workers=1, max_memory=2, offload_gra="phase"):
        super().__init__(blocks_num, offload_ratio, phases_num)
        self.offload_gra = offload_gra
        self.worker_stop_event = threading.Event()
        self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
        self.disk_task_queue = queue.PriorityQueue()
        self.disk_workers = []
        self.release_workers = []
        self._start_disk_workers(num_disk_workers)
        self.initial_prefetch_done = False
        self.pending_tasks = {}
        self.task_lock = threading.Lock()
        self.last_used_time = {}
        self.time_lock = threading.Lock()

    def _start_disk_workers(self, num_workers):
        for i in range(num_workers):
            if self.offload_gra == "phase":
                worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
            else:
                worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
            worker.start()
            self.disk_workers.append(worker)

    def _disk_worker_loop(self):
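        """Worker loop for phase granularity: load queued phases from disk into the pinned-memory buffer."""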
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, phase_idx, phase = task

                phase.load_from_disk()
                self.pin_memory_buffer.push((block_idx, phase_idx), phase)

                with self.task_lock:
                    if (block_idx, phase_idx) in self.pending_tasks:
                        del self.pending_tasks[(block_idx, phase_idx)]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _disk_worker_loop_block(self):
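        """Worker loop for block granularity: load all phases of a queued block from disk into the pinned-memory buffer."""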
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, block = task

                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

                with self.task_lock:
                    if block_idx in self.pending_tasks:
                        del self.pending_tasks[block_idx]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _async_prefetch_block(self, blocks, next_block_idx=None):
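        """Queue disk-load tasks for ``next_block_idx`` (defaults to the block after the newest cached one)."""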
        if next_block_idx is None:
            next_block_idx = self.pin_memory_buffer.get_max_block_index()

        if next_block_idx < 0:
            next_block_idx = 0

        if self.offload_gra == "phase":
            for phase_idx in range(self.phases_num):
                obj_key = (next_block_idx, phase_idx)

                if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                    continue

                with self.task_lock:
                    self.pending_tasks[obj_key] = True

                phase = blocks[next_block_idx].compute_phases[phase_idx]

                priority_key = (next_block_idx, phase_idx)
                self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
        else:
            obj_key = next_block_idx
            if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                return

            with self.task_lock:
                self.pending_tasks[obj_key] = True

            block = blocks[next_block_idx]
            self.disk_task_queue.put((obj_key, (next_block_idx, block)))

    def _sync_prefetch_block(self, blocks):
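        """Synchronously load blocks from disk, starting at block 0, until the pinned-memory buffer is nearly full."""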
        block_idx = 0
        while not self.pin_memory_buffer.is_nearly_full():
            if self.offload_gra == "phase":
                for phase_idx in range(self.phases_num):
                    phase = blocks[block_idx].compute_phases[phase_idx]
                    logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
                    phase.load_from_disk()
                    self.pin_memory_buffer.push((block_idx, phase_idx), phase)
            else:
                block = blocks[block_idx]
                logger.info(f"Synchronous loading: block={block_idx}")
                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

            block_idx += 1

    def prefetch_weights_from_disk(self, blocks):
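        """One-time warm-up: synchronously fill the pinned-memory buffer on the first call."""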
        if self.initial_prefetch_done:
            return

        self._sync_prefetch_block(blocks)
        self.initial_prefetch_done = True

    def prefetch_weights(self, block_idx, blocks):
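        """Stage block ``block_idx`` onto the GPU from the pinned-memory buffer.

        Waits for (or issues) the disk load if the block is not cached yet, then
        offloads and evicts the previously used block.
        """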
        obj_key = block_idx

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}")
            else:
                logger.info("Not find prefetch block={block_idx} task.")
                logger.info("Sync prefetch block={block_idx}.")
                self._async_prefetch_block(blocks, block_idx)
                start_time = time.time()
                # Block-granularity fallback: wait until the whole block lands in the buffer.
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 15:
                        raise TimeoutError(f"Load timeout: block={block_idx}")

        with torch.cuda.stream(self.cuda_load_stream):
            block = self.pin_memory_buffer.get(obj_key)
            block.to_cuda_async()
            self.active_weights[2] = (obj_key, block)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_block_num:
                if self.active_weights[1] is not None:
                    old_key, old_block = self.active_weights[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_block.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def prefetch_phase(self, block_idx, phase_idx, blocks):
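        """Stage one compute phase onto the GPU from the pinned-memory buffer.

        Waits for (or issues) the disk load if the phase is not cached yet, then
        offloads and evicts the previously used phase.
        """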
        obj_key = (block_idx, phase_idx)

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
            else:
                logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
                logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
                self._async_prefetch_block(blocks, block_idx)
                start_time = time.time()
                while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")

        with torch.cuda.stream(self.cuda_load_stream):
            phase = self.pin_memory_buffer.get(obj_key)
            phase.to_cuda_async()
            self.active_weights[2] = (obj_key, phase)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.active_weights[1] is not None:
                    old_key, old_phase = self.active_weights[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_phase.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def shutdown(self):
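        """Stop all disk workers: drain the task queue, enqueue sentinel tasks, and join the threads."""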
        self.worker_stop_event.set()

        while not self.disk_task_queue.empty():
            try:
                self.disk_task_queue.get_nowait()
            except queue.Empty:
                continue

        for _ in self.disk_workers:
            self.disk_task_queue.put((0, None))

        for worker in self.disk_workers:
            worker.join(timeout=5)

        for worker in self.release_workers:
            worker.join(timeout=5)

        logger.info("All worker threads have been closed")

    def clear(self):
        self.pin_memory_buffer.clear()
        self.shutdown()


class MemoryBuffer:
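    """Size-bounded FIFO cache of pinned CPU weights, keyed by block index or ``(block_idx, phase_idx)``.

    Entries are evicted explicitly via ``pop``/``pop_front``; ``is_nearly_full``
    reports when 90% of ``max_memory_bytes`` is in use.
    """
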
    def __init__(self, max_memory_bytes=8 * (1024**3)):
        self.cache = OrderedDict()
        self.max_mem = max_memory_bytes
        self.used_mem = 0
        self.obj_size_map = {}
        self.lock = threading.Lock()
        self.insertion_order = []
        self.insertion_index = 0

    def push(self, key, obj):
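        """Insert ``obj`` under ``key`` if absent and account for its memory.

        Sizes are cached: blocks (objects with ``compute_phases``) are all assumed
        to have the size of the first block measured; phases are measured once per
        phase index.
        """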
        with self.lock:
            if key in self.cache:
                return
            if hasattr(obj, "compute_phases"):
                obj_idx = key
                if len(self.obj_size_map) == 0:
                    _size = 0
                    for phase in obj.compute_phases:
                        _size += phase.calculate_size()
                    self.obj_size_map[0] = _size
                size = self.obj_size_map[0]
            else:
                _, obj_idx = key
                if obj_idx not in self.obj_size_map:
                    self.obj_size_map[obj_idx] = obj.calculate_size()
                size = self.obj_size_map[obj_idx]

            self.cache[key] = (size, obj, self.insertion_index)
            self.insertion_order.append((key, self.insertion_index))
            self.insertion_index += 1
            self.used_mem += size

    def _remove_key(self, key):
        if key in self.cache:
            size, obj, idx = self.cache.pop(key)
            try:
                if hasattr(obj, "compute_phases"):
                    for phase in obj.compute_phases:
                        phase.clear()
                else:
                    obj.clear()
            except Exception as e:
                logger.info(f"Error clearing obj: {e}")
            self.used_mem -= size

            self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]

    def get(self, key, default=None):
        with self.lock:
            if key in self.cache:
                size, obj, idx = self.cache[key]
                return obj
        return default

    def exists(self, key):
        with self.lock:
            return key in self.cache

    def pop_front(self):
        with self.lock:
            if not self.insertion_order:
                return False
            front_key, _ = self.insertion_order[0]
            self._remove_key(front_key)
            return True

    def pop(self, key):
        with self.lock:
            if key in self.cache:
                self._remove_key(key)
                return True
        return False

    def is_nearly_full(self):
        with self.lock:
            return self.used_mem >= self.max_mem * 0.9

    def get_max_block_index(self):
        with self.lock:
            if not self.cache:
                return -1
            # Keys are either block indices or (block_idx, phase_idx) tuples;
            # wrap around the hardcoded total of 40 blocks.
            last_key = next(reversed(self.cache))
            last_block_idx = last_key[0] if isinstance(last_key, tuple) else last_key
            return (last_block_idx + 1) % 40

    def clear(self):
        with self.lock:
            for key in list(self.cache.keys()):
                self._remove_key(key)

            self.insertion_order = []
            self.insertion_index = 0
            self.used_mem = 0
            torch.cuda.empty_cache()
            gc.collect()
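

# Illustrative usage sketch (not part of this module): a layer-by-layer forward
# pass could drive the phase-granularity manager roughly as below, overlapping
# compute of the current phase with prefetch of the next one. The indexing
# helper and the forward call are hypothetical; blocks are assumed to expose
# ``compute_phases`` plus the ``load_from_disk`` / ``to_cuda_async`` /
# ``to_cpu_async`` hooks used above.
#
#   manager = LazyWeightAsyncStreamManager(
#       blocks_num=len(blocks), phases_num=2, offload_gra="phase"
#   )
#   manager.prefetch_weights_from_disk(blocks)
#   manager.prefetch_phase(0, 0, blocks)
#   manager.swap_phases()
#   total = len(blocks) * manager.phases_num
#   for step in range(total):
#       if step + 1 < total:
#           next_block, next_phase = divmod(step + 1, manager.phases_num)
#           manager.prefetch_phase(next_block, next_phase, blocks)
#       with torch.cuda.stream(manager.compute_stream):
#           pass  # run the phase held in manager.active_weights[0]
#       manager.swap_phases()
#   manager.shutdown()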