import torch
import torch.nn as nn


class MemoryEfficientBlocks(nn.Module):
    """Run a deep stack of identical blocks with only two resident on the GPU.

    While block i executes on a high-priority compute stream, block i + 1's
    weights are prefetched on a separate load stream; the two resident blocks
    are then swapped (double buffering).
    """

    def __init__(self, block_class, num_blocks, **block_params):
        super().__init__()
        self.block_class = block_class
        self.num_blocks = num_blocks
        self.block_params = block_params

        # Initialize the two resident blocks (ping-pong buffers)
        self.active_blocks = nn.ModuleList([block_class(**block_params) for _ in range(2)])

        # Dedicated CUDA streams so weight loading can overlap with compute
        # (in CUDA, a lower number means higher priority)
        self.compute_stream = torch.cuda.Stream(priority=-1)  # high priority
        self.load_stream = torch.cuda.Stream(priority=0)  # normal priority

        # Free cached GPU memory and cap per-process usage. Note that
        # torch.cuda.empty_cache() returns None, so it cannot be used to
        # allocate pinned memory; the pinned staging buffers for async
        # transfers are created in initialize_weights() instead.
        torch.cuda.empty_cache()
        torch.cuda.memory.set_per_process_memory_fraction(0.8)  # limit GPU memory usage

        # CPU-side buffer holding every block's weights
        self.weight_buffer = []

    def initialize_weights(self, checkpoint, key):
        """加载所有权重到CPU内存"""
        # checkpoint = torch.load(checkpoint_path, map_location='cpu')
        for i in range(self.num_blocks):
            prefix = f"{key}.{i}."
            block_weights = {
                k.replace(prefix, ""): v.pin_memory()  # pinning enables true async H2D copies
                for k, v in checkpoint.items()
                if prefix in k
            }
            self.weight_buffer.append(block_weights)
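        # Example (hypothetical key layout): with key="blocks" and checkpoint
        # keys such as "blocks.0.linear.weight", self.weight_buffer[0] ends up
        # as {"linear.weight": <pinned CPU tensor>, ...}.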

    def prefetch_weights(self, block_idx):
        """在独立CUDA流中预加载下一个block的权重"""
        with torch.cuda.stream(self.load_stream):
            next_weights = self.weight_buffer[block_idx]
            next_weights = {k: v.cuda(non_blocking=True) for k, v in next_weights.items()}
            self.active_blocks[1].load_state_dict(next_weights)

    def swap_blocks(self):
        """交换两个block并更新权重"""
        # 等待计算完成
        self.compute_stream.synchronize()
        # 等待加载完成
        self.load_stream.synchronize()

        # Swap the active (compute) and standby (prefetch) blocks
        self.active_blocks[0], self.active_blocks[1] = self.active_blocks[1], self.active_blocks[0]

    def forward(self, *args, **kwargs):
        """前向传播,同时进行计算和权重加载"""
        # import pdb; pdb.set_trace()
        for i in range(self.num_blocks):
            if i == 0:
                self.active_blocks[0].load_state_dict(self.weight_buffer[0])

            # Run the current block on the main compute stream
            with torch.cuda.stream(self.compute_stream):
                current_block = self.active_blocks[0]
                outputs = current_block(*args, **kwargs)

            # Prefetch the next block's weights on the separate load stream
            if i < self.num_blocks - 1:
                self.prefetch_weights(i + 1)

            # Swap blocks (synchronizes both streams first)
            self.swap_blocks()

            # Feed this block's outputs into the next block as its inputs.
            # Note: len(outputs) on a bare tensor would measure the batch
            # dimension, so check the type instead; the inner loop also uses
            # its own index to avoid shadowing the outer loop variable.
            args = list(args)
            if isinstance(outputs, torch.Tensor):
                args[0] = outputs
            else:
                for j in range(len(outputs)):
                    args[j] = outputs[j]
            args = tuple(args)

        return outputs
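

# A minimal usage sketch: `TransformerBlock`, the sizes, and the "blocks" key
# prefix below are hypothetical stand-ins; substitute the project's real block
# class and checkpoint layout. Requires a CUDA device.
if __name__ == "__main__":
    class TransformerBlock(nn.Module):
        """Toy block used only for this demo."""

        def __init__(self, dim):
            super().__init__()
            self.linear = nn.Linear(dim, dim)

        def forward(self, x):
            return self.linear(x)

    num_blocks, dim = 4, 64
    blocks = MemoryEfficientBlocks(TransformerBlock, num_blocks, dim=dim).cuda()

    # Build an in-memory CPU "checkpoint" with keys like "blocks.0.linear.weight"
    reference = [TransformerBlock(dim) for _ in range(num_blocks)]
    checkpoint = {
        f"blocks.{i}.{name}": tensor
        for i, block in enumerate(reference)
        for name, tensor in block.state_dict().items()
    }

    blocks.initialize_weights(checkpoint, key="blocks")
    out = blocks(torch.randn(2, dim, device="cuda"))
    print(out.shape)  # torch.Size([2, 64])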