manager.py 4.04 KB
Newer Older
PengGao's avatar
PengGao committed
1
import torch
Gu Shiqiao's avatar
Gu Shiqiao committed
2
from packaging.version import parse
PengGao's avatar
PengGao committed
3

4
5
6
7
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)

gushiqiao's avatar
gushiqiao committed
8

9
class WeightAsyncStreamManager(object):
10
11
    def __init__(self, offload_granularity):
        """Set up the streams used for asynchronous weight offloading.

        Args:
            offload_granularity: Either ``"block"`` or ``"phase"`` — selects
                which kind of buffer the manager shuttles between host and
                device (see ``init_cpu_buffer`` / ``init_cuda_buffer``).
        """
        self.offload_granularity = offload_granularity

        # Dedicated stream for the one-time first-buffer upload.
        self.init_stream = torch_device_module.Stream(priority=0)
        self.need_init_first_buffer = True

        # Strip any local-build suffix (e.g. "2.7.0+cu121") before comparing.
        torch_version = parse(torch.__version__.split("+")[0])

        # NOTE(review): the CUDA/torch>=2.7 path gives both streams priority 1,
        # while every other configuration uses (load=0, compute=-1) — presumably
        # tuned per backend; confirm against the target runtime's priority range.
        use_modern_cuda = AI_DEVICE == "cuda" and torch_version >= parse("2.7")
        load_priority, compute_priority = (1, 1) if use_modern_cuda else (0, -1)
        self.cuda_load_stream = torch_device_module.Stream(priority=load_priority)
        self.compute_stream = torch_device_module.Stream(priority=compute_priority)

    def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None):
        self.need_init_first_buffer = True
        if self.offload_granularity == "block":
            assert blocks_cpu_buffer is not None
            self.cpu_buffers = [blocks_cpu_buffer[i] for i in range(len(blocks_cpu_buffer))]
        elif self.offload_granularity == "phase":
            assert phases_cpu_buffer is not None
            self.cpu_buffers = [phases_cpu_buffer[i] for i in range(len(phases_cpu_buffer))]
        else:
            raise NotImplementedError
32
33

    def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
34
        self.need_init_first_buffer = True
35
36
37
38
39
40
41
42
        if self.offload_granularity == "block":
            assert blocks_cuda_buffer is not None
            self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
        elif self.offload_granularity == "phase":
            assert phases_cuda_buffer is not None
            self.cuda_buffers = [phases_cuda_buffer[i] for i in range(len(phases_cuda_buffer))]
        else:
            raise NotImplementedError
43

44
    def init_first_buffer(self, blocks, adapter_block_idx=None):
        """Synchronously upload the first unit of weights into device buffer 0.

        The source state dict comes from the CPU staging buffer when one is
        registered, otherwise directly from ``blocks`` (whole block, or the
        block's first compute phase, per ``self.offload_granularity``).
        Blocks until the upload completes, then clears
        ``need_init_first_buffer``.
        """
        with torch_device_module.stream(self.init_stream):
            # Pick the state-dict source, then do a single load call.
            if hasattr(self, "cpu_buffers"):
                state = self.cpu_buffers[0].state_dict()
            elif self.offload_granularity == "block":
                state = blocks[0].state_dict()
            else:
                state = blocks[0].compute_phases[0].state_dict()
            self.cuda_buffers[0].load_state_dict(state, 0, adapter_block_idx)
        # Wait for the copy so buffer 0 is usable immediately.
        self.init_stream.synchronize()
        self.need_init_first_buffer = False
55
56

    def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
        """Asynchronously stage block ``block_idx`` into device buffer 1.

        Issued on ``cuda_load_stream`` so it overlaps with compute on the
        front buffer; ``swap_blocks`` later synchronizes and rotates the
        buffers. When CPU staging buffers exist, weights are first pulled
        from disk into host memory, then copied to the device.
        """
        with torch_device_module.stream(self.cuda_load_stream):
            if hasattr(self, "cpu_buffers"):
                staging = self.cpu_buffers[1]
                staging.load_state_dict_from_disk(block_idx, adapter_block_idx)
                self.cuda_buffers[1].load_state_dict(staging.state_dict(), block_idx, adapter_block_idx)
            else:
                self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)

    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
        """Asynchronously stage one compute phase into its device buffer.

        Same pattern as ``prefetch_weights`` but at phase granularity: the
        buffer slot is selected by ``phase_idx`` rather than the fixed
        back-buffer index, and the fallback source is
        ``blocks[block_idx].compute_phases[phase_idx]``.
        """
        with torch_device_module.stream(self.cuda_load_stream):
            if hasattr(self, "cpu_buffers"):
                staging = self.cpu_buffers[phase_idx]
                staging.load_state_dict_from_disk(block_idx, adapter_block_idx)
                self.cuda_buffers[phase_idx].load_state_dict(staging.state_dict(), block_idx, adapter_block_idx)
            else:
                source = blocks[block_idx].compute_phases[phase_idx]
                self.cuda_buffers[phase_idx].load_state_dict(source.state_dict(), block_idx, adapter_block_idx)
gushiqiao's avatar
gushiqiao committed
71

72
    def swap_blocks(self):
        """Rotate the device double buffer after a prefetch completes.

        Waits for both the load stream (prefetch finished) and the compute
        stream (front buffer no longer in use) before exchanging buffer 0
        and buffer 1.
        """
        self.cuda_load_stream.synchronize()
        self.compute_stream.synchronize()
        front = self.cuda_buffers[0]
        self.cuda_buffers[0] = self.cuda_buffers[1]
        self.cuda_buffers[1] = front
79
80

    def swap_phases(self):
81
        self.cuda_load_stream.synchronize()
82
        self.compute_stream.synchronize()