import gc
import queue
import threading
import time
from collections import OrderedDict

import torch
from loguru import logger
from packaging.version import parse


class WeightAsyncStreamManager(object):
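    """Double-buffered weight manager that overlaps host-to-device copies with compute.

    Dedicated CUDA streams are kept for initialization, asynchronous weight
    loading, and computation; preallocated GPU buffers are filled on the load
    stream so the next block or phase is ready while the current one executes.
    """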
    def __init__(self, offload_granularity):
        self.offload_granularity = offload_granularity
        self.init_stream = torch.cuda.Stream(priority=0)
        torch_version = parse(torch.__version__.split("+")[0])
        if torch_version >= parse("2.7"):
            self.cuda_load_stream = torch.cuda.Stream(priority=1)
            self.compute_stream = torch.cuda.Stream(priority=1)
        else:
            self.cuda_load_stream = torch.cuda.Stream(priority=0)
            self.compute_stream = torch.cuda.Stream(priority=-1)

    def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
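        """Bind the preallocated CUDA buffers used for prefetching, selected by offload granularity."""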
        if self.offload_granularity == "block":
            assert blocks_cuda_buffer is not None
            self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
        elif self.offload_granularity == "phase":
            assert phases_cuda_buffer is not None
            self.cuda_buffers = [phases_cuda_buffer[i] for i in range(len(phases_cuda_buffer))]
        else:
            raise NotImplementedError

    def init_first_buffer(self, blocks, adapter_block_idx=None):
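        """Load the first block (or its first phase) into buffer 0 on the init stream and wait for the copy."""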
        if self.offload_granularity == "block":
            with torch.cuda.stream(self.init_stream):
                self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
        else:
            with torch.cuda.stream(self.init_stream):
                self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
        self.init_stream.synchronize()

    def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
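        """Asynchronously copy block `block_idx`'s weights into the spare buffer on the load stream."""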
        with torch.cuda.stream(self.cuda_load_stream):
            self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)

    def swap_blocks(self):
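        """Synchronize both streams, then swap the active and prefetched buffers."""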
        self.cuda_load_stream.synchronize()
        self.compute_stream.synchronize()
        self.cuda_buffers[0], self.cuda_buffers[1] = (
            self.cuda_buffers[1],
            self.cuda_buffers[0],
        )

    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
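        """Asynchronously load one compute phase of `block_idx` into the buffer slot for `phase_idx`."""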
        with torch.cuda.stream(self.cuda_load_stream):
            self.cuda_buffers[phase_idx].load_state_dict(
                blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx
            )

    def swap_phases(self):
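        """Wait on the load and compute streams before moving on to the next phase."""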
        self.cuda_load_stream.synchronize()
        self.compute_stream.synchronize()


class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
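    """Stream manager that additionally prefetches weights from disk.

    Background worker threads pull load tasks from a priority queue, read
    weights from disk into a size-bounded pinned-memory buffer, and the
    manager then copies them to the GPU asynchronously, releasing entries
    from the pinned buffer once they have been consumed.
    """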
    def __init__(
        self,
        blocks_num,
        offload_ratio=1,
        phases_num=1,
        num_disk_workers=1,
        max_memory=2,
        offload_gra="phase",
    ):
        # The base class now takes only the offload granularity; keep the
        # sizing parameters on the instance (assumed fix, since the original
        # call no longer matched the base __init__ signature shown above).
        super().__init__(offload_gra)
        self.blocks_num = blocks_num
        self.offload_ratio = offload_ratio
        self.phases_num = phases_num
        self.offload_gra = offload_gra
        self.worker_stop_event = threading.Event()
        self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
        self.disk_task_queue = queue.PriorityQueue()
        self.disk_workers = []
        self.release_workers = []
        self._start_disk_workers(num_disk_workers)
        self.initial_prefetch_done = False
        self.pending_tasks = {}
        self.task_lock = threading.Lock()
        self.last_used_time = {}
        self.time_lock = threading.Lock()

    def _start_disk_workers(self, num_workers):
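        """Spawn daemon threads that service the disk prefetch queue (per-phase or per-block)."""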
        for i in range(num_workers):
            if self.offload_gra == "phase":
                worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
            else:
                worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
            worker.start()
            self.disk_workers.append(worker)

    def _disk_worker_loop(self):
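        """Phase-granularity worker: load queued phases from disk and push them into pinned memory."""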
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, phase_idx, phase = task

                phase.load_from_disk()
                self.pin_memory_buffer.push((block_idx, phase_idx), phase)

                with self.task_lock:
                    if (block_idx, phase_idx) in self.pending_tasks:
                        del self.pending_tasks[(block_idx, phase_idx)]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _disk_worker_loop_block(self):
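        """Block-granularity worker: load all phases of a queued block from disk and push the block."""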
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, block = task

                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

                with self.task_lock:
                    if block_idx in self.pending_tasks:
                        del self.pending_tasks[block_idx]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _async_prefetch_block(self, blocks, next_block_idx=None):
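        """Queue disk-load tasks for `next_block_idx` (or the next uncached block) unless already cached or pending."""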
        if next_block_idx is None:
            next_block_idx = self.pin_memory_buffer.get_max_block_index()

        if next_block_idx < 0:
            next_block_idx = 0

        if next_block_idx == self.blocks_num:
            return

        if self.offload_gra == "phase":
            for phase_idx in range(self.phases_num):
                obj_key = (next_block_idx, phase_idx)

                if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                    continue

                with self.task_lock:
                    self.pending_tasks[obj_key] = True

                phase = blocks[next_block_idx].compute_phases[phase_idx]

                priority_key = (next_block_idx, phase_idx)
                self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
        else:
            obj_key = next_block_idx
            if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                return

            with self.task_lock:
                self.pending_tasks[obj_key] = True

            block = blocks[next_block_idx]
            self.disk_task_queue.put((obj_key, (next_block_idx, block)))

    def _sync_prefetch_block(self, blocks):
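        """Blocking warm-up: load blocks from disk in order until the pinned-memory buffer is nearly full."""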
        block_idx = 0
        while not self.pin_memory_buffer.is_nearly_full():
            if self.offload_gra == "phase":
                for phase_idx in range(self.phases_num):
                    phase = blocks[block_idx].compute_phases[phase_idx]
                    logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
                    phase.load_from_disk()
                    self.pin_memory_buffer.push((block_idx, phase_idx), phase)
            else:
                block = blocks[block_idx]
                logger.info(f"Synchronous loading: block={block_idx}")
                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

            block_idx += 1
            if block_idx == self.blocks_num:
                break

    def prefetch_weights_from_disk(self, blocks):
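        """One-time synchronous warm-up of the pinned-memory buffer."""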
        if self.initial_prefetch_done:
            return

        self._sync_prefetch_block(blocks)
        self.initial_prefetch_done = True

    def prefetch_weights(self, block_idx, blocks):
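        """Ensure `block_idx` is staged in pinned memory (waiting on or issuing a disk load if needed),
        copy it to the GPU asynchronously, and move the previously prefetched block back to CPU."""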
        obj_key = block_idx

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}")
            else:
                logger.info("Not find prefetch block={block_idx} task.")
                logger.info("Sync prefetch block={block_idx}.")
                self._async_prefetch_block(blocks, block_idx)
                start_time = time.time()
                for phase_idx in self.phases_num:
                    while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
                        time.sleep(0.001)
                        if time.time() - start_time > 15:
                            raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")

        with torch.cuda.stream(self.cuda_load_stream):
            block = self.pin_memory_buffer.get(obj_key)
            block.to_cuda_async()
            self.cuda_buffers[2] = (obj_key, block)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_blocks_num:
                if self.cuda_buffers[1] is not None:
                    old_key, old_block = self.cuda_buffers[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_block.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def prefetch_phase(self, block_idx, phase_idx, blocks):
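        """Phase-granularity version of `prefetch_weights`: stage the phase in pinned memory,
        copy it to the GPU, and move the previously prefetched phase back to CPU."""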
        obj_key = (block_idx, phase_idx)

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
            else:
                logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
                logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
                self._async_prefetch_block(blocks, block_idx)
                start_time = time.time()
                while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")

        with torch.cuda.stream(self.cuda_load_stream):
            phase = self.pin_memory_buffer.get(obj_key)
            phase.to_cuda_async()
            self.cuda_buffers[2] = (obj_key, phase)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.cuda_buffers[1] is not None:
                    old_key, old_phase = self.cuda_buffers[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_phase.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def shutdown(self):
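        """Stop the workers: drain the task queue, send one sentinel per worker, then join them."""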
        self.worker_stop_event.set()

        while not self.disk_task_queue.empty():
            try:
                self.disk_task_queue.get_nowait()
            except queue.Empty:
                continue

        for _ in self.disk_workers:
            self.disk_task_queue.put((0, None))

        for worker in self.disk_workers:
            worker.join(timeout=5)

        for worker in self.release_workers:
            worker.join(timeout=5)

        logger.info("All worker threads have been closed")

    def clear(self):
        self.pin_memory_buffer.clear()
        self.shutdown()


class MemoryBuffer:
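    """Thread-safe, size-bounded FIFO cache for weights staged in CPU memory.

    Keys are either a block index or a (block_idx, phase_idx) tuple; entry
    sizes are cached in `obj_size_map` so repeated pushes avoid recomputing them.
    """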
    def __init__(self, max_memory_bytes=8 * (1024**3)):
        self.cache = OrderedDict()
        self.max_mem = max_memory_bytes
        self.used_mem = 0
        self.obj_size_map = {}
        self.lock = threading.Lock()
        self.insertion_order = []
        self.insertion_index = 0

    def push(self, key, obj):
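        """Insert `obj` under `key`, accounting for its (cached) size; duplicate keys are ignored."""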
        with self.lock:
            if key in self.cache:
                return
            if hasattr(obj, "compute_phases"):
                obj_idx = key
                if len(self.obj_size_map) == 0:
                    _size = 0
                    for phase in obj.compute_phases:
                        _size += phase.calculate_size()
                    self.obj_size_map[0] = _size
                size = self.obj_size_map[0]
            else:
                _, obj_idx = key
                if obj_idx not in self.obj_size_map:
                    self.obj_size_map[obj_idx] = obj.calculate_size()
                size = self.obj_size_map[obj_idx]

            self.cache[key] = (size, obj, self.insertion_index)
            self.insertion_order.append((key, self.insertion_index))
            self.insertion_index += 1
            self.used_mem += size

    def _remove_key(self, key):
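        # Internal helper; callers must already hold self.lock.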
        if key in self.cache:
            size, obj, idx = self.cache.pop(key)
            try:
                if hasattr(obj, "compute_phases"):
                    for phase in obj.compute_phases:
                        phase.clear()
                else:
                    obj.clear()
            except Exception as e:
                logger.info(f"Error clearing obj: {e}")
            self.used_mem -= size

            self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]

    def get(self, key, default=None):
        with self.lock:
            if key in self.cache:
                size, obj, idx = self.cache[key]
                return obj
        return default

    def exists(self, key):
        with self.lock:
            return key in self.cache

    def pop_front(self):
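        """Evict the oldest (first-inserted) entry; return False if the buffer is empty."""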
        with self.lock:
            if not self.insertion_order:
                return False
            front_key, _ = self.insertion_order[0]
            self._remove_key(front_key)
            return True

    def pop(self, key):
        with self.lock:
            if key in self.cache:
                self._remove_key(key)
                return True
        return False

    def is_nearly_full(self):
        with self.lock:
            return self.used_mem >= self.max_mem * 0.9

    def get_max_block_index(self):
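        """Return the block index to prefetch next, derived from the most recently inserted key."""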
        with self.lock:
            if not self.cache:
                return -1
            last_key = next(reversed(self.cache))
            block_idx = last_key[0] if isinstance(last_key, tuple) else last_key
            # NOTE: wraps around a hard-coded block count of 40.
            return (block_idx + 1) % 40

    def clear(self):
        with self.lock:
            for key in list(self.cache.keys()):
                self._remove_key(key)

            self.insertion_order = []
            self.insertion_index = 0
            self.used_mem = 0
            torch.cuda.empty_cache()
            gc.collect()