import time
from concurrent.futures import ThreadPoolExecutor

import torch
from loguru import logger
from packaging.version import parse
from tqdm import tqdm

from lightx2v.utils.profiler import ExcludedProfilingContext
from lightx2v_platform.base.global_var import AI_DEVICE

# Resolve the device backend module (e.g. torch.cuda) for the active accelerator.
torch_device_module = getattr(torch, AI_DEVICE)


class WeightAsyncStreamManager(object):
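    """Overlap weight transfers with computation via dedicated device streams.

    Device (and, for disk offload, CPU) buffers are double-buffered: while the
    compute stream works on the front buffer, the copy stream prefetches the
    next block's or phase's weights into the back buffer.
    """
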
    def __init__(self, offload_granularity):
        self.offload_granularity = offload_granularity
        self.init_stream = torch_device_module.Stream(priority=0)
        self.need_init_first_buffer = True
        self.lazy_load = False
        # Drop any local build suffix (e.g. "2.7.0+cu121") before parsing the version.
        torch_version = parse(torch.__version__.split("+")[0])
        # Lower values mean higher stream priority. On CUDA with torch >= 2.7
        # both streams share the same priority; otherwise the compute stream is
        # given priority over the weight-loading stream.
        if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
            self.cuda_load_stream = torch_device_module.Stream(priority=1)
            self.compute_stream = torch_device_module.Stream(priority=1)
        else:
            self.cuda_load_stream = torch_device_module.Stream(priority=0)
            self.compute_stream = torch_device_module.Stream(priority=-1)

    def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None):
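        """Register pre-allocated CPU staging buffers matching the configured
        offload granularity ("block" or "phase").
        """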
        self.need_init_first_buffer = True
        if self.offload_granularity == "block":
            assert blocks_cpu_buffer is not None
            self.cpu_buffers = list(blocks_cpu_buffer)
        elif self.offload_granularity == "phase":
            assert phases_cpu_buffer is not None
            self.cpu_buffers = list(phases_cpu_buffer)
        else:
            raise NotImplementedError(f"Unsupported offload_granularity: {self.offload_granularity!r}")

    def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
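        """Register pre-allocated device-side weight buffers matching the
        configured offload granularity ("block" or "phase").
        """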
        self.need_init_first_buffer = True
        if self.offload_granularity == "block":
            assert blocks_cuda_buffer is not None
            self.cuda_buffers = list(blocks_cuda_buffer)
        elif self.offload_granularity == "phase":
            assert phases_cuda_buffer is not None
            self.cuda_buffers = list(phases_cuda_buffer)
        else:
            raise NotImplementedError(f"Unsupported offload_granularity: {self.offload_granularity!r}")

    def init_first_buffer(self, blocks, adapter_block_idx=None):
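        """Synchronously load the first block's (or first phase's) weights into
        device buffer 0 before the main loop starts.
        """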
        with torch_device_module.stream(self.init_stream):
            if hasattr(self, "cpu_buffers"):
                self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
            else:
                if self.offload_granularity == "block":
                    self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
                else:
                    self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
        self.init_stream.synchronize()
        self.need_init_first_buffer = False

    def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
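        """Asynchronously stage the weights of block `block_idx` into the back
        device buffer on the copy stream, reading through the back CPU buffer
        (and disk) when CPU staging is enabled.
        """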
        with torch_device_module.stream(self.cuda_load_stream):
            if hasattr(self, "cpu_buffers"):
                self.cpu_buffers[1].load_state_dict_from_disk(block_idx, adapter_block_idx)
                self.cuda_buffers[1].load_state_dict(self.cpu_buffers[1].state_dict(), block_idx, adapter_block_idx)
            else:
                self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)

    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
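        """Asynchronously stage one compute phase of block `block_idx` into its
        per-phase device buffer, sourcing from the front CPU buffer when CPU
        staging is enabled.
        """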
        with torch_device_module.stream(self.cuda_load_stream):
            if hasattr(self, "cpu_buffers"):
                self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx)
            else:
                self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)

    def swap_blocks(self):
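        """Wait for the in-flight copy and compute to finish, then swap the
        front (compute) and back (prefetch) device buffers.
        """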
        self.cuda_load_stream.synchronize()
        self.compute_stream.synchronize()
        self.cuda_buffers[0], self.cuda_buffers[1] = (
            self.cuda_buffers[1],
            self.cuda_buffers[0],
        )

    def swap_phases(self):
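        """Synchronize the copy and compute streams between phases; per-phase
        device buffers are indexed directly, so no pointer swap is needed.
        """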
        self.cuda_load_stream.synchronize()
        self.compute_stream.synchronize()

    @ExcludedProfilingContext("🔥 warm_up_cpu_buffers")
    def warm_up_cpu_buffers(self, blocks_num):
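        """Read every block's weights from disk once through both CPU buffer
        sets (warming caches and allocations), then re-prime them for the
        first step.
        """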
        logger.info("🔥 Warming up cpu buffers...")
        for i in tqdm(range(blocks_num)):
            for phase in self.cpu_buffers[0]:
                phase.load_state_dict_from_disk(i, None)
            for phase in self.cpu_buffers[1]:
                phase.load_state_dict_from_disk(i, None)

        # Re-prime for the first inference step: block 0 in the front buffer
        # and block 1 in the back (prefetch) buffer.
        for phase in self.cpu_buffers[0]:
            phase.load_state_dict_from_disk(0, None)
        for phase in self.cpu_buffers[1]:
            phase.load_state_dict_from_disk(1, None)
        logger.info("✅ CPU buffers warm-up completed.")

    def init_lazy_load(self, num_workers=6):
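        """Enable lazy loading: subsequent disk reads are dispatched to a
        background thread pool instead of blocking the caller.
        """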
        self.lazy_load = True
        self.executor = ThreadPoolExecutor(max_workers=num_workers)
        self.prefetch_futures = []
        self.prefetch_block_idx = -1

    def start_prefetch_block(self, block_idx, adapter_block_idx=None):
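        """Kick off background disk reads that fill the back CPU buffer with
        the weights of block `block_idx`.
        """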
        self.prefetch_block_idx = block_idx
        self.prefetch_futures = []
        for phase in self.cpu_buffers[1]:
            future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
            self.prefetch_futures.append(future)

    def swap_cpu_buffers(self):
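        """Block until the pending disk prefetch finishes, then swap the front
        and back CPU buffers.
        """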
        wait_start = time.time()
        already_done = all(f.done() for f in self.prefetch_futures)
        for f in self.prefetch_futures:
            f.result()
        wait_time = time.time() - wait_start
        logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
        self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]

    def __del__(self):
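        """Drain any outstanding prefetch futures and shut the worker pool down."""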
        if hasattr(self, "executor") and self.executor is not None:
            for f in self.prefetch_futures:
                if not f.done():
                    f.result()
            self.executor.shutdown(wait=False)
            self.executor = None
            logger.debug("ThreadPoolExecutor shut down successfully.")