import gc
import queue
import threading
import time
from collections import OrderedDict

import torch
from loguru import logger


class WeightAsyncStreamManager(object):
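    """Manages asynchronous weight prefetching with dedicated CUDA streams.

    Weights for the next block (or phase) are copied into preallocated CUDA
    buffers on a separate load stream while the current block runs on the
    compute stream; ``swap_blocks`` / ``swap_phases`` synchronize both streams
    between steps.
    """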
    def __init__(self, offload_granularity):
        self.offload_granularity = offload_granularity
        self.init_stream = torch.cuda.Stream(priority=0)
        self.cuda_load_stream = torch.cuda.Stream(priority=1)
        self.compute_stream = torch.cuda.Stream(priority=1)

    def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
        if self.offload_granularity == "block":
            assert blocks_cuda_buffer is not None
            self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
        elif self.offload_granularity == "phase":
            assert phases_cuda_buffer is not None
            self.cuda_buffers = [phases_cuda_buffer[i] for i in range(len(phases_cuda_buffer))]
        else:
            raise NotImplementedError

    def init_first_buffer(self, blocks, adapter_block_idx=None):
        if self.offload_granularity == "block":
            with torch.cuda.stream(self.init_stream):
                self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
        else:
            with torch.cuda.stream(self.init_stream):
                self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
        self.init_stream.synchronize()

    def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
        with torch.cuda.stream(self.cuda_load_stream):
            self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)

    def swap_blocks(self):
        self.cuda_load_stream.synchronize()
        self.compute_stream.synchronize()
        self.cuda_buffers[0], self.cuda_buffers[1] = (
            self.cuda_buffers[1],
            self.cuda_buffers[0],
        )

    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
        with torch.cuda.stream(self.cuda_load_stream):
            self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)

    def swap_phases(self):
        self.cuda_load_stream.synchronize()
        self.compute_stream.synchronize()


class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
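    """Stream manager that additionally streams weights from disk on demand.

    Background worker threads load block/phase weights from disk into a
    pinned-memory cache (``MemoryBuffer``); ``prefetch_weights`` and
    ``prefetch_phase`` then move cached entries onto the GPU asynchronously
    and copy the previously used entry back to CPU.
    """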
    def __init__(
        self,
        blocks_num,
        offload_ratio=1,
        phases_num=1,
        num_disk_workers=1,
        max_memory=2,
        offload_gra="phase",
    ):
        # The parent manager only takes the offload granularity; the block /
        # phase counts are kept here because the prefetch paths below use them.
        super().__init__(offload_gra)
        self.offload_gra = offload_gra
        self.blocks_num = blocks_num
        self.phases_num = phases_num
        # Assumed derivation: offload_ratio controls how many leading blocks /
        # phases get copied back to pinned CPU memory after compute.
        self.offload_blocks_num = int(blocks_num * offload_ratio)
        self.offload_phases_num = int(blocks_num * phases_num * offload_ratio)
        # Extra stream for the CPU-bound copies and a three-slot (key, obj)
        # buffer table used by prefetch_weights / prefetch_phase (assumed layout).
        self.cpu_load_stream = torch.cuda.Stream(priority=0)
        self.cuda_buffers = [None, None, None]
        self.worker_stop_event = threading.Event()
        self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
        self.disk_task_queue = queue.PriorityQueue()
        self.disk_workers = []
        self.release_workers = []
        self._start_disk_workers(num_disk_workers)
        self.initial_prefetch_done = False
        self.pending_tasks = {}
        self.task_lock = threading.Lock()
        self.last_used_time = {}
        self.time_lock = threading.Lock()

    def _start_disk_workers(self, num_workers):
        for i in range(num_workers):
            if self.offload_gra == "phase":
                worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
            else:
                worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
            worker.start()
            self.disk_workers.append(worker)

    def _disk_worker_loop(self):
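        """Worker loop (phase granularity): load phases from disk into the pinned-memory buffer."""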
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, phase_idx, phase = task

                phase.load_from_disk()
                self.pin_memory_buffer.push((block_idx, phase_idx), phase)

                with self.task_lock:
                    if (block_idx, phase_idx) in self.pending_tasks:
                        del self.pending_tasks[(block_idx, phase_idx)]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _disk_worker_loop_block(self):
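        """Worker loop (block granularity): load every phase of a block from disk, then cache the block."""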
        while not self.worker_stop_event.is_set():
            try:
                _, task = self.disk_task_queue.get(timeout=0.5)
                if task is None:
                    break

                block_idx, block = task

                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

                with self.task_lock:
                    if block_idx in self.pending_tasks:
                        del self.pending_tasks[block_idx]
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Disk worker thread error: {e}")

    def _async_prefetch_block(self, blocks, next_block_idx=None):
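        """Queue disk-load tasks for the next block (or its phases) unless already cached or pending."""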
        if next_block_idx is None:
            next_block_idx = self.pin_memory_buffer.get_max_block_index()

        if next_block_idx < 0:
            next_block_idx = 0

        if next_block_idx == self.blocks_num:
            return

        if self.offload_gra == "phase":
            for phase_idx in range(self.phases_num):
                obj_key = (next_block_idx, phase_idx)

                if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                    continue

                with self.task_lock:
                    self.pending_tasks[obj_key] = True

                phase = blocks[next_block_idx].compute_phases[phase_idx]

                priority_key = (next_block_idx, phase_idx)
                self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
        else:
            obj_key = next_block_idx
            if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
                return

            with self.task_lock:
                self.pending_tasks[obj_key] = True

            block = blocks[next_block_idx]
            self.disk_task_queue.put((obj_key, (next_block_idx, block)))

    def _sync_prefetch_block(self, blocks):
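        """Synchronously load blocks from disk until the pinned-memory buffer is nearly full."""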
        block_idx = 0
        while not self.pin_memory_buffer.is_nearly_full():
            if self.offload_gra == "phase":
                for phase_idx in range(self.phases_num):
                    phase = blocks[block_idx].compute_phases[phase_idx]
                    logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
                    phase.load_from_disk()
                    self.pin_memory_buffer.push((block_idx, phase_idx), phase)
            else:
                block = blocks[block_idx]
                logger.info(f"Synchronous loading: block={block_idx}")
                for phase in block.compute_phases:
                    phase.load_from_disk()
                self.pin_memory_buffer.push(block_idx, block)

            block_idx += 1
            if block_idx == self.blocks_num:
                break

    def prefetch_weights_from_disk(self, blocks):
        if self.initial_prefetch_done:
            return

        self._sync_prefetch_block(blocks)
        self.initial_prefetch_done = True

    def prefetch_weights(self, block_idx, blocks):
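        """Ensure block ``block_idx`` is in the pinned-memory cache, then copy it to the GPU asynchronously."""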
        obj_key = block_idx

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}")
            else:
                logger.info("Not find prefetch block={block_idx} task.")
                logger.info("Sync prefetch block={block_idx}.")
                self._async_prefetch_block(blocks, block_idx)
                start_time = time.time()
                for phase_idx in self.phases_num:
                    while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
                        time.sleep(0.001)
                        if time.time() - start_time > 15:
                            raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")

        with torch.cuda.stream(self.cuda_load_stream):
            block = self.pin_memory_buffer.get(obj_key)
            block.to_cuda_async()
            self.cuda_buffers[2] = (obj_key, block)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_blocks_num:
                if self.cuda_buffers[1] is not None:
                    old_key, old_block = self.cuda_buffers[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_block.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def prefetch_phase(self, block_idx, phase_idx, blocks):
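        """Ensure phase ``(block_idx, phase_idx)`` is in the pinned-memory cache, then copy it to the GPU asynchronously."""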
        obj_key = (block_idx, phase_idx)

        if not self.pin_memory_buffer.exists(obj_key):
            is_loading = False
            with self.task_lock:
                if obj_key in self.pending_tasks:
                    is_loading = True

            if is_loading:
                start_time = time.time()
                while not self.pin_memory_buffer.exists(obj_key):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
            else:
                logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
                logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
                self._async_prefetch_block(blocks, block_idx)
                start_time = time.time()
                while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
                    time.sleep(0.001)
                    if time.time() - start_time > 5:
                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")

        with torch.cuda.stream(self.cuda_load_stream):
            phase = self.pin_memory_buffer.get(obj_key)
            phase.to_cuda_async()
            self.cuda_buffers[2] = (obj_key, phase)

        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.cuda_buffers[1] is not None:
                    old_key, old_phase = self.cuda_buffers[1]
                    if self.pin_memory_buffer.exists(old_key):
                        old_phase.to_cpu_async()
                        self.pin_memory_buffer.pop(old_key)

    def shutdown(self):
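        """Stop the disk workers: drain the task queue, send sentinel tasks, and join all worker threads."""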
        self.worker_stop_event.set()

        while not self.disk_task_queue.empty():
            try:
                self.disk_task_queue.get_nowait()
            except queue.Empty:
                continue

        for _ in self.disk_workers:
            self.disk_task_queue.put((0, None))

        for worker in self.disk_workers:
            worker.join(timeout=5)

        for worker in self.release_workers:
            worker.join(timeout=5)

        logger.info("All worker threads have been closed")

    def clear(self):
        self.pin_memory_buffer.clear()
        self.shutdown()


class MemoryBuffer:
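    """Thread-safe, size-bounded cache of weight objects kept in CPU memory.

    Entries are keyed by a block index or a ``(block, phase)`` tuple; insertion
    order is tracked so the oldest entry can be evicted via ``pop_front``.
    """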
    def __init__(self, max_memory_bytes=8 * (1024**3)):
        self.cache = OrderedDict()
        self.max_mem = max_memory_bytes
        self.used_mem = 0
        self.obj_size_map = {}
        self.lock = threading.Lock()
        self.insertion_order = []
        self.insertion_index = 0

    def push(self, key, obj):
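        """Insert ``obj`` under ``key``, tracking its size; keys already present are ignored."""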
        with self.lock:
            if key in self.cache:
                return
            if hasattr(obj, "compute_phases"):
                obj_idx = key
                if len(self.obj_size_map) == 0:
                    _size = 0
                    for phase in obj.compute_phases:
                        _size += phase.calculate_size()
                    self.obj_size_map[0] = _size
                size = self.obj_size_map[0]
            else:
                _, obj_idx = key
                if obj_idx not in self.obj_size_map:
                    self.obj_size_map[obj_idx] = obj.calculate_size()
                size = self.obj_size_map[obj_idx]

            self.cache[key] = (size, obj, self.insertion_index)
            self.insertion_order.append((key, self.insertion_index))
            self.insertion_index += 1
            self.used_mem += size

    def _remove_key(self, key):
        if key in self.cache:
            size, obj, idx = self.cache.pop(key)
            try:
                if hasattr(obj, "compute_phases"):
                    for phase in obj.compute_phases:
                        phase.clear()
                else:
                    obj.clear()
            except Exception as e:
                logger.info(f"Error clearing obj: {e}")
            self.used_mem -= size

            self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]

    def get(self, key, default=None):
        with self.lock:
            if key in self.cache:
                size, obj, idx = self.cache[key]
                return obj
        return default

    def exists(self, key):
        with self.lock:
            return key in self.cache

    def pop_front(self):
        with self.lock:
            if not self.insertion_order:
                return False
            front_key, _ = self.insertion_order[0]
            self._remove_key(front_key)
            return True

    def pop(self, key):
        with self.lock:
            if key in self.cache:
                self._remove_key(key)
                return True
        return False

    def is_nearly_full(self):
        with self.lock:
            return self.used_mem >= self.max_mem * 0.9

    def get_max_block_index(self):
        with self.lock:
            if not self.cache:
                return -1
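            # NOTE: the hard-coded "% 40" appears to assume a 40-block model;
            # the returned next-block index wraps back to 0 after the last block.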
            last_key = next(reversed(self.cache))
            if isinstance(last_key, tuple):
                return (last_key[0] + 1) % 40
            else:
                return (last_key + 1) % 40

    def clear(self):
        with self.lock:
            for key in list(self.cache.keys()):
                self._remove_key(key)

            self.insertion_order = []
            self.insertion_index = 0
            self.used_mem = 0
            torch.cuda.empty_cache()
            gc.collect()