wrap.py 3.18 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
import functools
from lightx2v.attentions.distributed.ulysses.attn import ulysses_attn

def parallelize_hunyuan(hunyuan_model):
Xinchi Huang's avatar
Xinchi Huang committed
5
    from lightx2v.attentions.distributed.utils.hunyuan.processor import pre_process, post_process
helloyongyang's avatar
helloyongyang committed
6
7
8
9
10
11
12
13
14
15
16
17
    """将 Hunyuan 模型的推理过程并行化,使用 Ulysses 注意力机制。

    参数:
        hunyuan_model: Hunyuan 模型实例,包含推理方法和其他属性。
    """
    # 将 Hunyuan 模型的并行注意力机制替换为 Ulysses 注意力
    hunyuan_model.transformer_infer.parallel_attention = ulysses_attn

    # 保存原始的推理方法,以便后续调用
    original_infer = hunyuan_model.infer

    @functools.wraps(hunyuan_model.__class__.infer)  # 保留原始推理方法的元信息
Xinchi Huang's avatar
Xinchi Huang committed
18
    def new_infer(self, text_encoders_output, image_encoder_output, args):
helloyongyang's avatar
helloyongyang committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
        """新的推理方法,处理输入并调用原始推理方法。

        参数:
            self: Hunyuan 模型实例
            text_encoders_output: 文本编码器的输出
            args: 其他参数

        返回:
            None
        """
        # 保存原始的潜在模型输入和频率数据
        self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (
            self.scheduler.latents, 
            self.scheduler.freqs_cos, 
            self.scheduler.freqs_sin
        )
        
        # 预处理输入数据以适应并行计算
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(
            self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin
        )

        # 调用原始推理方法,获取输出
Xinchi Huang's avatar
Xinchi Huang committed
42
43
        original_infer(
            text_encoders_output, image_encoder_output, args
helloyongyang's avatar
helloyongyang committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
        )

        # 对输出进行后处理
        self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
        
        # 恢复原始的潜在模型输入和频率数据
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (
            self.scheduler.ori_latents, 
            self.scheduler.ori_freqs_cos, 
            self.scheduler.ori_freqs_sin
        )

        # return combined_output  # 返回处理后的输出(当前被注释掉)

    # 将新的推理方法绑定到 Hunyuan 模型实例
    new_infer = new_infer.__get__(hunyuan_model)
Xinchi Huang's avatar
Xinchi Huang committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
    hunyuan_model.infer = new_infer  # 替换原始推理方法


def parallelize_wan(wan_model):
    from lightx2v.attentions.distributed.utils.wan.processor import pre_process, post_process
    wan_model.transformer_infer.parallel_attention = ulysses_attn

    original_infer = wan_model.transformer_infer.infer

    @functools.wraps(wan_model.transformer_infer.__class__.infer)  # 保留原始推理方法的元信息
    def new_infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):

        x = pre_process(
            x
        )

        x = original_infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)

        x = post_process(
            x
        )

        return x

    new_infer = new_infer.__get__(wan_model.transformer_infer)
    wan_model.transformer_infer.infer = new_infer  # 替换原始推理方法