model.py 7.21 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import torch
import torch.distributed as dist

from lightx2v.models.networks.base_model import BaseTransformerModel
from lightx2v.models.networks.z_image.infer.offload.transformer_infer import ZImageOffloadTransformerInfer
from lightx2v.models.networks.z_image.infer.post_infer import ZImagePostInfer
from lightx2v.models.networks.z_image.infer.pre_infer import ZImagePreInfer
from lightx2v.models.networks.z_image.infer.transformer_infer import ZImageTransformerInfer
from lightx2v.models.networks.z_image.weights.post_weights import ZImagePostWeights
from lightx2v.models.networks.z_image.weights.pre_weights import ZImagePreWeights
from lightx2v.models.networks.z_image.weights.transformer_weights import ZImageTransformerWeights
from lightx2v.utils.custom_compiler import compiled_method
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *


class ZImageTransformerModel(BaseTransformerModel):
    """Transformer model wrapper for Z-Image.

    Binds the Z-Image pre/transformer/post weight classes and infer classes.
    Each ``infer`` call reads ``self.scheduler.latents`` and stores the
    (optionally classifier-free-guided) noise prediction in
    ``self.scheduler.noise_pred``. Sequence parallelism is rejected at
    construction time.
    """

    pre_weight_class = ZImagePreWeights
    transformer_weight_class = ZImageTransformerWeights
    post_weight_class = ZImagePostWeights

    def __init__(self, model_path, config, device, lora_path=None, lora_strength=1.0):
        """Build the weights and infer modules for Z-Image.

        Args:
            model_path: Checkpoint location, forwarded to ``BaseTransformerModel``.
            config: Mapping read for (among others) ``seq_parallel``,
                ``feature_caching``, ``enable_cfg`` and ``cfg_parallel``.
            device: Forwarded to ``BaseTransformerModel``.
            lora_path: Optional LoRA path, forwarded to ``BaseTransformerModel``.
            lora_strength: LoRA scale, forwarded to ``BaseTransformerModel``.

        Raises:
            NotImplementedError: If ``config["seq_parallel"]`` is truthy, or if
                ``config["feature_caching"]`` is not ``"NoCaching"``.
        """
        # The 4th positional argument of the base constructor is passed as None;
        # its meaning is defined in BaseTransformerModel (not visible here) — TODO confirm.
        super().__init__(model_path, config, device, None, lora_path, lora_strength)
        if self.lazy_load:
            # Presumably keeps per-layer ("layers.") weights out of the eager load so
            # they can be loaded lazily later — verify against BaseTransformerModel.
            self.remove_keys.extend(["layers."])

        if self.config["seq_parallel"]:
            raise NotImplementedError("Sequence parallel is not implemented for ZImageTransformerModel")

        self._init_infer_class()
        self._init_weights()
        self._init_infer()

    def _init_infer_class(self):
        """Select the infer classes for the pre, transformer and post stages.

        The transformer infer class depends on ``self.cpu_offload``: the offload
        variant is used when CPU offloading is enabled.

        Raises:
            NotImplementedError: If ``config["feature_caching"]`` is anything other
                than ``"NoCaching"``; no cached transformer-infer variant exists here.
        """
        feature_caching = self.config["feature_caching"]
        if feature_caching != "NoCaching":
            # Must be a raise, not `assert NotImplementedError`: asserting an exception
            # class is always truthy, so the check would never fire. The unset
            # transformer_infer_class would then surface later as an AttributeError.
            raise NotImplementedError(f"feature_caching={feature_caching!r} is not implemented for ZImageTransformerModel")
        self.transformer_infer_class = ZImageOffloadTransformerInfer if self.cpu_offload else ZImageTransformerInfer
        self.pre_infer_class = ZImagePreInfer
        self.post_infer_class = ZImagePostInfer

    def _init_infer(self):
        """Instantiate the selected infer classes with the model config.

        If the transformer infer object exposes an ``offload_manager`` attribute,
        the base-class offload manager setup is run as well.
        """
        self.transformer_infer = self.transformer_infer_class(self.config)
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        if hasattr(self.transformer_infer, "offload_manager"):
            self._init_offload_manager()

    @torch.no_grad()
    def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True):
        """Run one pre -> transformer -> post pass for a single prompt embedding.

        Args:
            latents_input: Latents passed to the pre-infer stage as ``hidden_states``.
            prompt_embeds: Text embeddings passed as ``encoder_hidden_states``.
            infer_condition: Stored on ``self.scheduler.infer_condition`` before the
                pass (True for the conditional prompt, False for the negative one).

        Returns:
            The post-infer output. Per the comments in ``infer`` this is already the
            image-token part, presumably shaped [B, T_img, out_dim] — confirm in
            ZImagePostInfer.
        """
        self.scheduler.infer_condition = infer_condition
        pre_infer_out = self.pre_infer.infer(
            weights=self.pre_weight,
            hidden_states=latents_input,
            encoder_hidden_states=prompt_embeds,
        )

        # NOTE(review): __init__ raises when seq_parallel is set, so this branch (and the
        # matching post-process below) is currently unreachable.
        if self.config["seq_parallel"]:
            pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)

        hidden_states = self.transformer_infer.infer(
            block_weights=self.transformer_weights,
            pre_infer_out=pre_infer_out,
        )
        noise_pred = self.post_infer.infer(
            self.post_weight,
            hidden_states,
            pre_infer_out.temb_img_silu,  # Use timestep embedding (t), not text embedding!
            image_tokens_len=pre_infer_out.image_tokens_len,
        )

        if self.config["seq_parallel"]:
            noise_pred = self._seq_parallel_post_process(noise_pred)
        return noise_pred

    @torch.no_grad()
    def _seq_parallel_pre_process(self, pre_infer_out):
        """Placeholder for sequence-parallel pre-processing; always raises."""
        raise NotImplementedError("Sequence parallel pre-process is not implemented for ZImageTransformerModel")

    @torch.no_grad()
    def _seq_parallel_post_process(self, noise_pred):
        """Placeholder for sequence-parallel post-processing; always raises."""
        raise NotImplementedError("Sequence parallel post-process is not implemented for ZImageTransformerModel")

    @compiled_method()
    @torch.no_grad()
    def infer(self, inputs):
        """Compute the noise prediction for the scheduler's current step.

        Reads ``self.scheduler.latents`` and stores the result in
        ``self.scheduler.noise_pred``. Returns None.

        With ``config["enable_cfg"]`` the conditional and unconditional predictions
        are combined with ``self.scheduler.sample_guide_scale``. They are computed
        either sequentially or split across a 2-rank ``cfg_p`` process group when
        ``config["cfg_parallel"]`` is set. The combined prediction is then rescaled
        to the norm of the conditional one.

        Args:
            inputs: Mapping whose ``"text_encoder_output"`` entry provides
                ``"prompt_embeds"`` and, when CFG is enabled,
                ``"negative_prompt_embeds"``.
        """
        if self.cpu_offload:
            # "model" granularity: move the whole model to the GPU once, on the first
            # step. Other granularities: only the non-block weights are moved here;
            # transformer blocks are presumably streamed by the offload infer class.
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()
                self.transformer_weights.non_block_weights_to_cuda()

        latents_input = self.scheduler.latents
        text_encoder_output = inputs["text_encoder_output"]

        if self.config["enable_cfg"]:
            if self.config["cfg_parallel"]:
                # ==================== CFG Parallel Processing ====================
                # Rank 0 computes the conditional branch and rank 1 the unconditional
                # one; the results are then exchanged with all_gather.
                cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
                assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2"
                cfg_p_rank = dist.get_rank(cfg_p_group)

                if cfg_p_rank == 0:
                    noise_pred = self._infer_cond_uncond(latents_input, text_encoder_output["prompt_embeds"], infer_condition=True)
                else:
                    noise_pred = self._infer_cond_uncond(latents_input, text_encoder_output["negative_prompt_embeds"], infer_condition=False)

                # post_infer already extracts the image part, so both ranks' outputs
                # share a shape, which all_gather requires.
                noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
                dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
                noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
                noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
            else:
                # ==================== CFG Processing ====================
                noise_pred_cond = self._infer_cond_uncond(latents_input, text_encoder_output["prompt_embeds"], infer_condition=True)
                noise_pred_uncond = self._infer_cond_uncond(latents_input, text_encoder_output["negative_prompt_embeds"], infer_condition=False)

                # post_infer already extracts the image part; defensively align the
                # sequence dimension (dim 1) in case the two passes ever differ.
                min_seq_len = min(noise_pred_cond.shape[1], noise_pred_uncond.shape[1])
                noise_pred_cond = noise_pred_cond[:, :min_seq_len, :]
                noise_pred_uncond = noise_pred_uncond[:, :min_seq_len, :]

            comb_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
            # Rescale the guided prediction so each vector along the last dim keeps the
            # L2 norm of the corresponding conditional prediction.
            # NOTE(review): a zero-norm vector in comb_pred yields inf/NaN here; there is
            # no epsilon guard.
            noise_pred_cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True)
            noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
            self.scheduler.noise_pred = comb_pred * (noise_pred_cond_norm / noise_norm)
        else:
            # ==================== No CFG Processing ====================
            # post_infer already extracts the image part; no further slicing needed.
            self.scheduler.noise_pred = self._infer_cond_uncond(latents_input, text_encoder_output["prompt_embeds"], infer_condition=True)

        if self.cpu_offload:
            # Mirror of the upload above: "model" granularity moves everything back to the
            # CPU after the last step; other granularities release non-block weights each step.
            # NOTE(review): the "wan2.2_moe" check looks carried over from the Wan model;
            # for Z-Image model_cls it is presumably always true — confirm.
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
                self.to_cpu()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()
                self.transformer_weights.non_block_weights_to_cpu()