model.py

import torch
import torch.distributed as dist

from lightx2v.models.networks.base_model import BaseTransformerModel
from lightx2v.models.networks.longcat_image.infer.post_infer import LongCatImagePostInfer
from lightx2v.models.networks.longcat_image.infer.pre_infer import LongCatImagePreInfer
from lightx2v.models.networks.longcat_image.infer.transformer_infer import LongCatImageTransformerInfer
from lightx2v.models.networks.longcat_image.weights.post_weights import LongCatImagePostWeights
from lightx2v.models.networks.longcat_image.weights.pre_weights import LongCatImagePreWeights
from lightx2v.models.networks.longcat_image.weights.transformer_weights import LongCatImageTransformerWeights
from lightx2v.utils.custom_compiler import compiled_method
from lightx2v.utils.envs import *


class LongCatImageTransformerModel(BaseTransformerModel):
    """Transformer model for LongCat Image.

    Handles weight loading and inference for the LongCat architecture
    (10 double-stream blocks + 20 single-stream blocks).
    """

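    # Weight containers for the three stages of the forward pass:
    # pre-embedding, transformer blocks, and the final projection.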
    pre_weight_class = LongCatImagePreWeights
    transformer_weight_class = LongCatImageTransformerWeights
    post_weight_class = LongCatImagePostWeights

    def __init__(self, config, model_path, device):
        super().__init__(model_path, config, device)
        # Use transformer_in_channels to avoid conflict with VAE's in_channels
        self.in_channels = self.config.get("transformer_in_channels", self.config.get("in_channels", 64))
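        # Resolution order: "transformer_in_channels" -> "in_channels" -> 64 (default).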
        self.attention_kwargs = {}
        if self.config["seq_parallel"]:
            raise NotImplementedError("Sequence parallel is not implemented for LongCatImageTransformerModel")
        self._init_infer_class()
        self._init_weights()
        self._init_infer()

    def _init_infer_class(self):
        self.transformer_infer_class = LongCatImageTransformerInfer
        self.pre_infer_class = LongCatImagePreInfer
        self.post_infer_class = LongCatImagePostInfer

    def _init_infer(self):
        self.transformer_infer = self.transformer_infer_class(self.config)
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        if hasattr(self.transformer_infer, "offload_manager"):
            self._init_offload_manager()

    @torch.no_grad()
    def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True):
        self.scheduler.infer_condition = infer_condition

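        # Stage 1: pre-infer embeds the latents and prompt embeddings and produces,
        # among other things, the timestep embedding `temb` consumed by the post step.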
        pre_infer_out = self.pre_infer.infer(
            weights=self.pre_weight,
            hidden_states=latents_input,
            encoder_hidden_states=prompt_embeds,
        )

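        # Stage 2: run the transformer blocks (10 double-stream + 20 single-stream,
        # per the class docstring).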
        hidden_states = self.transformer_infer.infer(
            block_weights=self.transformer_weights,
            pre_infer_out=pre_infer_out,
        )

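        # Stage 3: post-infer projects the hidden states back to a noise prediction.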
        noise_pred = self.post_infer.infer(self.post_weight, hidden_states, pre_infer_out.temb)

        return noise_pred

    @torch.no_grad()
    def _seq_parallel_pre_process(self, pre_infer_out):
        raise NotImplementedError("Sequence parallel pre-process is not implemented for LongCatImageTransformerModel")

    @torch.no_grad()
    def _seq_parallel_post_process(self, x):
        raise NotImplementedError("Sequence parallel post-process is not implemented for LongCatImageTransformerModel")

    @compiled_method()
    @torch.no_grad()
    def infer(self, inputs):
        if self.cpu_offload:
            self.to_cuda()

        latents = self.scheduler.latents

        if self.config.get("enable_cfg", True):
            # Check if CFG parallel should be used
            # Note: I2I task may have different sequence lengths for positive/negative prompts,
            # which is not yet supported in CFG parallel mode
            use_cfg_parallel = self.config.get("cfg_parallel", False)
            if use_cfg_parallel and hasattr(self.scheduler, "input_image_latents") and self.scheduler.input_image_latents is not None:
                # I2I task: check if sequence lengths match
                if hasattr(self.scheduler, "image_rotary_emb") and hasattr(self.scheduler, "negative_image_rotary_emb"):
                    pos_len = self.scheduler.image_rotary_emb[0].shape[0]
                    neg_len = self.scheduler.negative_image_rotary_emb[0].shape[0]
                    if pos_len != neg_len:
                        from lightx2v.utils.utils import logger

                        if dist.get_rank() == 0:
                            logger.warning(f"CFG parallel disabled for I2I task due to sequence length mismatch (positive: {pos_len}, negative: {neg_len}). Falling back to sequential CFG.")
                        use_cfg_parallel = False

            if use_cfg_parallel:
                # ==================== CFG Parallel Processing ====================
                cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
                assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2"
                cfg_p_rank = dist.get_rank(cfg_p_group)

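                # Each rank computes one branch (rank 0: conditional, rank 1: unconditional);
                # the two predictions are gathered below.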
                if cfg_p_rank == 0:
                    noise_pred = self._infer_cond_uncond(latents, inputs["text_encoder_output"]["prompt_embeds"], infer_condition=True)
                else:
                    noise_pred = self._infer_cond_uncond(latents, inputs["text_encoder_output"]["negative_prompt_embeds"], infer_condition=False)

                noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
                dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
                noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
                noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1

                # Apply CFG with optional renormalization
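                # (Assumed to be the standard combination, with guidance scale and
                #  renormalization handled inside the scheduler:
                #  noise_pred = uncond + scale * (cond - uncond).)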
                noise_pred = self.scheduler.apply_cfg(noise_pred_cond, noise_pred_uncond)
                self.scheduler.noise_pred = noise_pred
            else:
                # ==================== CFG Sequential Processing ====================
                noise_pred_cond = self._infer_cond_uncond(latents, inputs["text_encoder_output"]["prompt_embeds"], infer_condition=True)
                noise_pred_uncond = self._infer_cond_uncond(latents, inputs["text_encoder_output"]["negative_prompt_embeds"], infer_condition=False)

                # Apply CFG with optional renormalization
                noise_pred = self.scheduler.apply_cfg(noise_pred_cond, noise_pred_uncond)
                self.scheduler.noise_pred = noise_pred
        else:
            # ==================== No CFG Processing ====================
            noise_pred = self._infer_cond_uncond(latents, inputs["text_encoder_output"]["prompt_embeds"], infer_condition=True)
            self.scheduler.noise_pred = noise_pred
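

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model). It shows how a
# runner could drive this class, based solely on what this file reads: the
# config keys, the placeholder path, and the scheduler interface (.latents,
# .infer_condition, .apply_cfg, .noise_pred) are assumptions inferred from
# the code above, not a verified lightx2v entry point.
# ---------------------------------------------------------------------------
# config = {
#     "seq_parallel": False,            # required; True raises NotImplementedError
#     "transformer_in_channels": 64,    # falls back to "in_channels", then 64
#     "enable_cfg": True,               # run cond + uncond passes and combine them
#     "cfg_parallel": False,            # True needs a 2-rank "cfg_p" group in config["device_mesh"]
# }
# model = LongCatImageTransformerModel(config, model_path="/path/to/longcat_image", device="cuda")
# model.scheduler = scheduler  # attached by the pipeline before inference
# model.infer({"text_encoder_output": {"prompt_embeds": prompt_embeds,
#                                       "negative_prompt_embeds": negative_prompt_embeds}})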