import time
from concurrent.futures import ThreadPoolExecutor

import torch
from loguru import logger
from packaging.version import parse
from tqdm import tqdm

from lightx2v.utils.profiler import ExcludedProfilingContext
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)
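# AI_DEVICE names the active backend (e.g. "cuda"), so torch_device_module
# resolves to the matching torch submodule (torch.cuda, torch.npu, ...) and
# gives backend-agnostic access to Stream and stream below.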

class WeightAsyncStreamManager:
    def __init__(self, offload_granularity):
        self.offload_granularity = offload_granularity
        # Dedicated stream for staging the first buffer before the main loop.
        self.init_stream = torch_device_module.Stream(priority=0)
        self.need_init_first_buffer = True
        self.lazy_load = False
        # Strip any local version suffix (e.g. "2.7.0+cu126") before comparing.
        torch_version = parse(torch.__version__.split("+")[0])
        if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
            self.cuda_load_stream = torch_device_module.Stream(priority=1)
            self.compute_stream = torch_device_module.Stream(priority=1)
        else:
            # Lower numbers mean higher priority, so compute outranks loading here.
            self.cuda_load_stream = torch_device_module.Stream(priority=0)
            self.compute_stream = torch_device_module.Stream(priority=-1)

    def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None):
        self.need_init_first_buffer = True
        # cpu_buffers is a front/back pair: index 0 is active, index 1 receives
        # the next prefetch (see swap_cpu_buffers).
        if self.offload_granularity == "block":
            assert blocks_cpu_buffer is not None
            self.cpu_buffers = list(blocks_cpu_buffer)
        elif self.offload_granularity == "phase":
            assert phases_cpu_buffer is not None
            self.cpu_buffers = list(phases_cpu_buffer)
        else:
            raise NotImplementedError

    def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
        self.need_init_first_buffer = True
        # Device-side staging buffers mirroring init_cpu_buffer.
        if self.offload_granularity == "block":
            assert blocks_cuda_buffer is not None
            self.cuda_buffers = list(blocks_cuda_buffer)
        elif self.offload_granularity == "phase":
            assert phases_cuda_buffer is not None
            self.cuda_buffers = list(phases_cuda_buffer)
        else:
            raise NotImplementedError

    def init_first_buffer(self, blocks, adapter_block_idx=None):
        # Stage the first block (or its first phase) synchronously so compute
        # can start immediately.
        with torch_device_module.stream(self.init_stream):
            if hasattr(self, "cpu_buffers"):
                self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
            else:
                if self.offload_granularity == "block":
                    self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
                else:
                    self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
        self.init_stream.synchronize()
        self.need_init_first_buffer = False

    def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
        # Queue the next block's weights on the load stream so the copy
        # overlaps with compute on the current block.
        with torch_device_module.stream(self.cuda_load_stream):
            if hasattr(self, "cpu_buffers"):
                # Disk -> back CPU buffer -> back device buffer.
                self.cpu_buffers[1].load_state_dict_from_disk(block_idx, adapter_block_idx)
                self.cuda_buffers[1].load_state_dict(self.cpu_buffers[1].state_dict(), block_idx, adapter_block_idx)
            else:
                self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)

    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
        with torch_device_module.stream(self.cuda_load_stream):
            if hasattr(self, "cpu_buffers"):
                self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx)
            else:
                self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)

    def swap_blocks(self):
        # Join both the in-flight prefetch and the current compute before
        # rotating the device buffers.
        self.cuda_load_stream.synchronize()
        self.compute_stream.synchronize()
        self.cuda_buffers[0], self.cuda_buffers[1] = (
            self.cuda_buffers[1],
            self.cuda_buffers[0],
        )

    def swap_phases(self):
        self.cuda_load_stream.synchronize()
        self.compute_stream.synchronize()

    @ExcludedProfilingContext("🔥 warm_up_cpu_buffers")
    def warm_up_cpu_buffers(self, blocks_num):
        logger.info("🔥 Warming up cpu buffers...")
        # Touch every block once so later disk reads are fast.
        for i in tqdm(range(blocks_num)):
            for phase in self.cpu_buffers[0]:
                phase.load_state_dict_from_disk(i, None)
            for phase in self.cpu_buffers[1]:
                phase.load_state_dict_from_disk(i, None)

        # Prime the double buffer for the first step: block 0 in the front
        # buffer, block 1 in the back buffer.
        for phase in self.cpu_buffers[0]:
            phase.load_state_dict_from_disk(0, None)
        for phase in self.cpu_buffers[1]:
            phase.load_state_dict_from_disk(1, None)
        logger.info("✅ CPU buffers warm-up completed.")

    def init_lazy_load(self, num_workers=6):
        # Disk reads for upcoming blocks are farmed out to a thread pool and
        # joined later in swap_cpu_buffers.
        self.lazy_load = True
        self.executor = ThreadPoolExecutor(max_workers=num_workers)
        self.prefetch_futures = []
        self.prefetch_block_idx = -1
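
    # Expected lazy-load call order (illustrative sketch; `i` is a block index):
    #   manager.init_lazy_load()
    #   manager.start_prefetch_block(i + 1)   # worker threads read block i+1 from disk
    #   ... compute block i from the front buffers ...
    #   manager.swap_cpu_buffers()            # join the futures, flip front/back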

    def start_prefetch_block(self, block_idx, adapter_block_idx=None):
        self.prefetch_block_idx = block_idx
        self.prefetch_futures = []
        for phase in self.cpu_buffers[1]:
            future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
            self.prefetch_futures.append(future)

    def swap_cpu_buffers(self):
        # Join the outstanding disk prefetch, then flip the CPU double buffer.
        wait_start = time.time()
        already_done = all(f.done() for f in self.prefetch_futures)
        for f in self.prefetch_futures:
            f.result()
        wait_time = time.time() - wait_start
        logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
        self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]

    def shutdown(self, wait=True):
        """Shutdown the thread pool executor and wait for all pending tasks to complete."""
        if hasattr(self, "executor") and self.executor is not None:
            # Wait for all pending futures to complete before shutting down
            if hasattr(self, "prefetch_futures"):
                for f in self.prefetch_futures:
                    try:
                        if not f.done():
                            f.result()
                    except Exception:
                        pass
            self.executor.shutdown(wait=wait)
            self.executor = None
            logger.debug("ThreadPoolExecutor shut down successfully.")

    def __del__(self):
        """Cleanup method to ensure executor is shut down when object is destroyed."""
        try:
            if hasattr(self, "executor") and self.executor is not None:
                self.executor.shutdown(wait=False)
        except Exception:
            pass
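

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the shipped API): how a caller
# might drive the manager for block-granularity offload. `blocks` and
# `make_cuda_buffer` are hypothetical stand-ins; the real device buffer type
# must expose load_state_dict(state_dict, block_idx, adapter_block_idx) and
# state_dict(). The actual compute call is left as a placeholder.
# ---------------------------------------------------------------------------
def _example_block_offload_loop(blocks, make_cuda_buffer):
    manager = WeightAsyncStreamManager(offload_granularity="block")
    # Two device buffers: index 0 is active, index 1 receives the prefetch.
    manager.init_cuda_buffer(blocks_cuda_buffer=[make_cuda_buffer(), make_cuda_buffer()])
    manager.init_first_buffer(blocks)
    for i in range(len(blocks)):
        if i + 1 < len(blocks):
            # Queue the H2D copy of block i+1 on the load stream.
            manager.prefetch_weights(i + 1, blocks)
        with torch_device_module.stream(manager.compute_stream):
            pass  # run block i's forward here, reading manager.cuda_buffers[0]
        # Synchronize both streams, then rotate so the prefetched block is live.
        manager.swap_blocks()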