manager.py 6.1 KB
Newer Older
1
2
from concurrent.futures import ThreadPoolExecutor

PengGao's avatar
PengGao committed
3
import torch
4
from loguru import logger
Gu Shiqiao's avatar
Gu Shiqiao committed
5
from packaging.version import parse
6
from tqdm import tqdm
PengGao's avatar
PengGao committed
7

8
from lightx2v.utils.profiler import ExcludedProfilingContext
9
10
11
12
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)

gushiqiao's avatar
gushiqiao committed
13

14
class WeightAsyncStreamManager(object):
    """Coordinates device streams and CPU/GPU double buffers for
    asynchronous weight offloading at block or phase granularity."""

    def __init__(self, offload_granularity):
        """Create the manager and its device streams.

        Args:
            offload_granularity: "block" or "phase"; selects which buffer
                layout the ``init_*_buffer`` helpers expect.
        """
        self.offload_granularity = offload_granularity
        # Dedicated stream for the one-off first-buffer upload.
        self.init_stream = torch_device_module.Stream(priority=0)
        self.need_init_first_buffer = True
        self.lazy_load = False

        # Drop any local version tag (e.g. "+cu121") before comparing.
        installed = parse(torch.__version__.split("+")[0])

        # NOTE(review): presumably torch >= 2.7 on CUDA accepts a wider
        # stream-priority range; older builds fall back to the legacy
        # (-1 = high, 0 = default) priorities — confirm against torch docs.
        if AI_DEVICE == "cuda" and installed >= parse("2.7"):
            load_priority, compute_priority = 1, 1
        else:
            load_priority, compute_priority = 0, -1
        self.cuda_load_stream = torch_device_module.Stream(priority=load_priority)
        self.compute_stream = torch_device_module.Stream(priority=compute_priority)

    def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None):
        self.need_init_first_buffer = True
        if self.offload_granularity == "block":
            assert blocks_cpu_buffer is not None
            self.cpu_buffers = [blocks_cpu_buffer[i] for i in range(len(blocks_cpu_buffer))]
        elif self.offload_granularity == "phase":
            assert phases_cpu_buffer is not None
            self.cpu_buffers = [phases_cpu_buffer[i] for i in range(len(phases_cpu_buffer))]
        else:
            raise NotImplementedError
38
39

    def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
40
        self.need_init_first_buffer = True
41
42
43
44
45
46
47
48
        if self.offload_granularity == "block":
            assert blocks_cuda_buffer is not None
            self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
        elif self.offload_granularity == "phase":
            assert phases_cuda_buffer is not None
            self.cuda_buffers = [phases_cuda_buffer[i] for i in range(len(phases_cuda_buffer))]
        else:
            raise NotImplementedError
49

50
    def init_first_buffer(self, blocks, adapter_block_idx=None):
        """Synchronously stage the first block/phase weights into GPU
        buffer 0 on the dedicated init stream."""
        # Pick the weight source: CPU staging buffers when disk offload is
        # active, otherwise the live model blocks.
        if hasattr(self, "cpu_buffers"):
            source = self.cpu_buffers[0][0]
        elif self.offload_granularity == "block":
            source = blocks[0]
        else:
            source = blocks[0].compute_phases[0]
        with torch_device_module.stream(self.init_stream):
            self.cuda_buffers[0].load_state_dict(source.state_dict(), 0, adapter_block_idx)
        # Block until the copy lands so compute can start immediately.
        self.init_stream.synchronize()
        self.need_init_first_buffer = False

    def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
        """Kick off an async copy of block ``block_idx``'s weights into the
        spare GPU buffer (index 1) on the load stream."""
        with torch_device_module.stream(self.cuda_load_stream):
            if hasattr(self, "cpu_buffers"):
                # Disk-offload path: pull from disk into the spare CPU
                # buffer, then stage that onto the GPU.
                staging = self.cpu_buffers[1]
                staging.load_state_dict_from_disk(block_idx, adapter_block_idx)
                self.cuda_buffers[1].load_state_dict(staging.state_dict(), block_idx, adapter_block_idx)
            else:
                # Weights already resident in host memory.
                self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)

    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
        """Async-load one compute phase of block ``block_idx`` into the GPU
        buffer reserved for ``phase_idx``, on the load stream."""
        target = self.cuda_buffers[phase_idx]
        with torch_device_module.stream(self.cuda_load_stream):
            if hasattr(self, "cpu_buffers"):
                # Disk-offload path: source from the current CPU buffer set.
                target.load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx)
            else:
                target.load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
gushiqiao's avatar
gushiqiao committed
76

77
    def swap_blocks(self):
78
        self.cuda_load_stream.synchronize()
79
80
81
82
        self.compute_stream.synchronize()
        self.cuda_buffers[0], self.cuda_buffers[1] = (
            self.cuda_buffers[1],
            self.cuda_buffers[0],
gushiqiao's avatar
gushiqiao committed
83
        )
84
85

    def swap_phases(self):
        """Barrier between phase prefetch and phase compute.

        NOTE(review): as visible here this only synchronizes both streams;
        unlike ``swap_blocks`` no buffer exchange follows — the original
        file may have had additional rotation logic here, confirm upstream.
        """
        # Wait for the in-flight phase copy on the load stream...
        self.cuda_load_stream.synchronize()
        # ...and for compute on the previous phase to drain.
        self.compute_stream.synchronize()
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117

    @ExcludedProfilingContext("🔥 warm_up_cpu_buffers")
    def warm_up_cpu_buffers(self, blocks_num):
        """Read every block's weights from disk once through both CPU
        buffer sets (presumably to warm the OS page cache — confirm), then
        leave buffer 0 holding block 0 and buffer 1 holding block 1 so the
        pipeline starts from a consistent state."""
        logger.info("🔥 Warming up cpu buffers...")
        for block_idx in tqdm(range(blocks_num)):
            for buffer_set in (self.cpu_buffers[0], self.cpu_buffers[1]):
                for phase in buffer_set:
                    phase.load_state_dict_from_disk(block_idx, None)

        # Re-seed the double buffers with the first two blocks.
        for buffer_idx in (0, 1):
            for phase in self.cpu_buffers[buffer_idx]:
                phase.load_state_dict_from_disk(buffer_idx, None)
        logger.info("✅ CPU buffers warm-up completed.")

    def init_lazy_load(self, num_workers=6):
        self.lazy_load = True
        self.executor = ThreadPoolExecutor(max_workers=num_workers)
        self.prefetch_futures = []
        self.prefetch_block_idx = -1

    def start_prefetch_block(self, block_idx, adapter_block_idx=None):
        self.prefetch_block_idx = block_idx
        self.prefetch_futures = []
        for phase in self.cpu_buffers[1]:
            future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
            self.prefetch_futures.append(future)

    def swap_cpu_buffers(self):
Gu Shiqiao's avatar
Gu Shiqiao committed
118
119
        #  wait_start = time.time()
        # already_done = all(f.done() for f in self.prefetch_futures)
120
121
        for f in self.prefetch_futures:
            f.result()
Gu Shiqiao's avatar
Gu Shiqiao committed
122
123
        # wait_time = time.time() - wait_start
        # logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
124
125
        self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]

Gu Shiqiao's avatar
Gu Shiqiao committed
126
    def __del__(self):
        """Best-effort teardown of the lazy-load thread pool at GC time."""
        executor = getattr(self, "executor", None)
        if executor is None:
            return
        # Drain outstanding prefetch work before tearing the pool down.
        for future in self.prefetch_futures:
            if not future.done():
                future.result()
        executor.shutdown(wait=False)
        self.executor = None
        logger.debug("ThreadPoolExecutor shut down successfully.")