# wrap.py: patch Hunyuan / Wan model instances to run inference with Ulysses attention.
import functools
from lightx2v.attentions.distributed.ulysses.attn import ulysses_attn

def parallelize_hunyuan(hunyuan_model):
    """Parallelize a Hunyuan model's inference using Ulysses attention.

    Replaces ``hunyuan_model.transformer_infer.parallel_attention`` with
    ``ulysses_attn`` and rebinds ``hunyuan_model.infer`` to a wrapper that
    pre-processes the scheduler inputs before the original ``infer`` runs
    and post-processes ``scheduler.noise_pred`` afterwards.

    Args:
        hunyuan_model: Hunyuan model instance. It must expose ``infer``,
            ``scheduler`` and ``transformer_infer``.

    Returns:
        None. The model instance is modified in place.
    """
    # Kept as a function-scope import, as in the original code.
    from lightx2v.attentions.distributed.utils.hunyuan.processor import pre_process, post_process

    # Swap the model's parallel attention implementation for Ulysses attention.
    hunyuan_model.transformer_infer.parallel_attention = ulysses_attn

    # Capture the original bound method before it is replaced below.
    original_infer = hunyuan_model.infer

    @functools.wraps(hunyuan_model.__class__.infer)  # keep the original method's metadata
    def new_infer(self, text_encoders_output, image_encoder_output, args):
        """Run the original ``infer`` on pre-processed scheduler inputs.

        Args:
            self: Hunyuan model instance.
            text_encoders_output: Output of the text encoders.
            image_encoder_output: Output of the image encoder.
            args: Remaining arguments, forwarded unchanged.

        Returns:
            None. The result is written to ``self.scheduler.noise_pred``.
        """
        # Remember the original latents and frequency data on the scheduler.
        self.scheduler.ori_latents = self.scheduler.latents
        self.scheduler.ori_freqs_cos = self.scheduler.freqs_cos
        self.scheduler.ori_freqs_sin = self.scheduler.freqs_sin

        try:
            # Pre-process the inputs for parallel computation.
            (
                self.scheduler.latents,
                self.scheduler.freqs_cos,
                self.scheduler.freqs_sin,
                split_dim,
            ) = pre_process(self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin)

            # The original method is already bound, so ``self`` is not passed.
            original_infer(text_encoders_output, image_encoder_output, args)

            # Post-process the prediction produced by the original method.
            self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
        finally:
            # Always restore the original inputs, even if inference failed,
            # so the scheduler never keeps the pre-processed tensors.
            self.scheduler.latents = self.scheduler.ori_latents
            self.scheduler.freqs_cos = self.scheduler.ori_freqs_cos
            self.scheduler.freqs_sin = self.scheduler.ori_freqs_sin

    # Bind the wrapper to this instance and install it in place of ``infer``.
    new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer


def parallelize_wan(wan_model):
    """Patch a Wan model's transformer inference to use Ulysses attention.

    Sets ``parallel_attention`` on ``wan_model.transformer_infer`` to
    ``ulysses_attn`` and rebinds that object's ``infer`` to a wrapper which
    passes ``x`` through ``pre_process`` before the original ``infer`` and
    passes the result through ``post_process`` afterwards.

    Args:
        wan_model: Wan model instance exposing ``transformer_infer``.

    Returns:
        None. ``wan_model.transformer_infer`` is modified in place.
    """
    from lightx2v.attentions.distributed.utils.wan.processor import pre_process, post_process

    transformer_infer = wan_model.transformer_infer
    transformer_infer.parallel_attention = ulysses_attn

    # Capture the original bound method before it is replaced below.
    original_infer = transformer_infer.infer

    @functools.wraps(transformer_infer.__class__.infer)  # keep the original method's metadata
    def wrapped_infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        prepared = pre_process(x)
        # The original method is already bound, so ``self`` is not forwarded.
        raw_out = original_infer(weights, grid_sizes, embed, prepared, embed0, seq_lens, freqs, context)
        return post_process(raw_out)

    # Bind the wrapper to the transformer-infer object and install it.
    transformer_infer.infer = wrapped_infer.__get__(transformer_infer)