wrap.py 2.35 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import functools
from lightx2v.attentions.distributed.ulysses.attn import ulysses_attn
from lightx2v.attentions.distributed.utils.process import pre_process, post_process

def parallelize_hunyuan(hunyuan_model):
    """将 Hunyuan 模型的推理过程并行化,使用 Ulysses 注意力机制。

    参数:
        hunyuan_model: Hunyuan 模型实例,包含推理方法和其他属性。
    """
    # 将 Hunyuan 模型的并行注意力机制替换为 Ulysses 注意力
    hunyuan_model.transformer_infer.parallel_attention = ulysses_attn

    # 保存原始的推理方法,以便后续调用
    original_infer = hunyuan_model.infer

    @functools.wraps(hunyuan_model.__class__.infer)  # 保留原始推理方法的元信息
    def new_infer(self, text_encoders_output, args):
        """新的推理方法,处理输入并调用原始推理方法。

        参数:
            self: Hunyuan 模型实例
            text_encoders_output: 文本编码器的输出
            args: 其他参数

        返回:
            None
        """
        # 保存原始的潜在模型输入和频率数据
        self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (
            self.scheduler.latents, 
            self.scheduler.freqs_cos, 
            self.scheduler.freqs_sin
        )
        
        # 预处理输入数据以适应并行计算
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(
            self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin
        )

        # 调用原始推理方法,获取输出
        output = original_infer(
            text_encoders_output, args
        )

        # 对输出进行后处理
        self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
        
        # 恢复原始的潜在模型输入和频率数据
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (
            self.scheduler.ori_latents, 
            self.scheduler.ori_freqs_cos, 
            self.scheduler.ori_freqs_sin
        )

        # return combined_output  # 返回处理后的输出(当前被注释掉)

    # 将新的推理方法绑定到 Hunyuan 模型实例
    new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer  # 替换原始推理方法