wrap.py 2.98 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
import functools
from lightx2v.attentions.distributed.ring.attn import ring_attn


def parallelize_hunyuan(hunyuan_model):
    """Patch a Hunyuan model instance so that inference uses ring attention.

    The model is modified in place:

    * ``hunyuan_model.transformer_infer.parallel_attention`` is set to
      ``ring_attn``.
    * ``hunyuan_model.infer`` is replaced by a bound wrapper that
      pre-processes the scheduler inputs, calls the original ``infer``,
      post-processes ``scheduler.noise_pred`` and finally restores the
      scheduler inputs.

    Args:
        hunyuan_model: Hunyuan model instance. It must expose
            ``transformer_infer``, ``infer`` and ``scheduler`` (with
            ``latents``, ``freqs_cos``, ``freqs_sin`` and ``noise_pred``).

    Returns:
        None. The model instance is patched in place.
    """
    # Kept as a function-scope import, matching the original layout.
    from lightx2v.attentions.distributed.utils.hunyuan.processor import pre_process, post_process

    # Route the transformer's parallel attention through ring attention.
    hunyuan_model.transformer_infer.parallel_attention = ring_attn

    # Capture the already-bound original method; the wrapper delegates to it.
    original_infer = hunyuan_model.infer

    @functools.wraps(hunyuan_model.__class__.infer)  # keep the original method's metadata
    def new_infer(self, text_encoders_output, image_encoder_output, args):
        """Run the original ``infer`` on pre-processed scheduler inputs.

        Args:
            self: Hunyuan model instance the wrapper is bound to.
            text_encoders_output: Output of the text encoders, forwarded
                unchanged to the original ``infer``.
            image_encoder_output: Output of the image encoder, forwarded
                unchanged to the original ``infer``.
            args: Additional arguments, forwarded unchanged.

        Returns:
            None. The result is left in ``self.scheduler.noise_pred``.
        """
        # Stash the untouched inputs on the scheduler so they can be restored.
        self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (
            self.scheduler.latents,
            self.scheduler.freqs_cos,
            self.scheduler.freqs_sin,
        )

        # Replace the inputs with their pre-processed versions. ``split_dim``
        # is returned by ``pre_process`` and handed back to ``post_process``.
        # NOTE(review): presumably this shards the inputs across ranks and
        # ``split_dim`` is the sharded dimension - confirm in the processor module.
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(
            self.scheduler.latents,
            self.scheduler.freqs_cos,
            self.scheduler.freqs_sin,
        )

        try:
            # Delegate to the original inference method.
            original_infer(text_encoders_output, image_encoder_output, args)

            # Post-process the prediction produced by the original method.
            self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
        finally:
            # Always restore the original inputs, even if inference or
            # post-processing raised, so the scheduler is never left holding
            # the pre-processed tensors.
            self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (
                self.scheduler.ori_latents,
                self.scheduler.ori_freqs_cos,
                self.scheduler.ori_freqs_sin,
            )

    # Bind the wrapper to this model instance and install it in place of the
    # original inference method.
    new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer
Xinchi Huang's avatar
Xinchi Huang committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71


def parallelize_wan(wan_model):
    """Patch a Wan model's transformer inference to use ring attention.

    The model is modified in place:

    * ``wan_model.transformer_infer.parallel_attention`` is set to
      ``ring_attn``.
    * ``wan_model.transformer_infer.infer`` is replaced by a bound wrapper
      that applies ``pre_process`` to ``x`` before the original ``infer``
      and ``post_process`` to its result.

    Args:
        wan_model: Wan model instance exposing ``transformer_infer``.

    Returns:
        None. The model instance is patched in place.
    """
    from lightx2v.attentions.distributed.utils.wan.processor import pre_process, post_process

    transformer = wan_model.transformer_infer
    transformer.parallel_attention = ring_attn

    # Already-bound original method that the wrapper delegates to.
    wrapped_infer = transformer.infer

    @functools.wraps(transformer.__class__.infer)  # keep the original method's metadata
    def sequence_parallel_infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        """Pre-process ``x``, run the original ``infer``, post-process the result."""
        prepared = pre_process(x)
        result = wrapped_infer(weights, grid_sizes, embed, prepared, embed0, seq_lens, freqs, context)
        return post_process(result)

    # Bind the wrapper to the transformer instance and install it in place of
    # the original inference method.
    transformer.infer = sequence_parallel_infer.__get__(transformer)