import glob
import json
import os

import torch
import torch.distributed as dist
from safetensors import safe_open

from lightx2v.utils.envs import *
from lightx2v.utils.utils import *

from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
from .infer.post_infer import QwenImagePostInfer
from .infer.pre_infer import QwenImagePreInfer
from .infer.transformer_infer import QwenImageTransformerInfer
from .weights.post_weights import QwenImagePostWeights
from .weights.pre_weights import QwenImagePreWeights
from .weights.transformer_weights import QwenImageTransformerWeights


class QwenImageTransformerModel:
    """Qwen-Image DiT model wrapper.

    Owns the pre/transformer/post weight containers and the matching infer
    engines, and runs one denoising step per call to :meth:`infer`.
    """

    # Weight-container classes; subclasses can override these to swap
    # in alternative weight implementations.
    pre_weight_class = QwenImagePreWeights
    transformer_weight_class = QwenImageTransformerWeights
    post_weight_class = QwenImagePostWeights
    def __init__(self, config):
        """Build the transformer model from *config*.

        Args:
            config: dict-like run configuration. Must contain "model_path",
                "mm_config" and "feature_caching"; may contain "cpu_offload",
                "offload_granularity", "task" and "device_mesh".
        """
        self.config = config
        self.model_path = os.path.join(config["model_path"], "transformer")
        self.cpu_offload = config.get("cpu_offload", False)
        self.offload_granularity = self.config.get("offload_granularity", "block")
        # With CPU offload, weights live on host and are staged to GPU on demand.
        self.device = torch.device("cpu") if self.cpu_offload else torch.device("cuda")

        # Reuse self.model_path instead of re-joining the same path (was duplicated).
        with open(os.path.join(self.model_path, "config.json"), "r") as f:
            transformer_config = json.load(f)
        self.in_channels = transformer_config["in_channels"]
        self.attention_kwargs = {}

        # mm_type "Default" means unquantized matmuls.
        self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
        self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)

        self._init_infer_class()
        self._init_weights()
        self._init_infer()

    def set_scheduler(self, scheduler):
        """Attach *scheduler* to this model and propagate it to every infer stage."""
        self.scheduler = scheduler
        for stage in (self.pre_infer, self.transformer_infer, self.post_infer):
            stage.set_scheduler(scheduler)

    def _init_infer_class(self):
        """Select the infer implementations for the configured caching/offload mode.

        Raises:
            NotImplementedError: if a feature-caching mode other than
                "NoCaching" is requested.
        """
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = QwenImageOffloadTransformerInfer if self.cpu_offload else QwenImageTransformerInfer
        else:
            # Was `assert NotImplementedError` — a truthy no-op that never fails.
            raise NotImplementedError(f"Unsupported feature_caching mode: {self.config['feature_caching']}")
        self.pre_infer_class = QwenImagePreInfer
        self.post_infer_class = QwenImagePostInfer

    def _init_weights(self, weight_dict=None):
        """Create the weight containers and populate them.

        Args:
            weight_dict: optional pre-loaded state dict. When None, weights are
                read from disk (only on the loader rank under a device mesh)
                and broadcast to all ranks.

        Raises:
            NotImplementedError: if pre-quantized DiT weights are requested
                (dit_quantized without weight_auto_quant).
        """
        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
        # Some layers run with float32 to achieve high accuracy
        sensitive_layer = {}

        if weight_dict is None:
            is_weight_loader = self._should_load_weights()
            if is_weight_loader:
                if not self.dit_quantized or self.weight_auto_quant:
                    # Load original weights
                    weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                else:
                    # Was `assert NotImplementedError` — a truthy no-op; fail loudly instead.
                    raise NotImplementedError("Loading pre-quantized DiT weights is not supported yet")

            if self.config.get("device_mesh") is not None:
                weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)

        # Both original branches ended in this same assignment; keep a single one.
        self.original_weight_dict = weight_dict

        # Initialize weight containers
        self.pre_weight = self.pre_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        # Load weights into containers
        self.pre_weight.load(self.original_weight_dict)
        self.transformer_weights.load(self.original_weight_dict)
        self.post_weight.load(self.original_weight_dict)

    def _should_load_weights(self):
        """Return True when this process should read checkpoints from disk.

        Single-GPU runs always load; under a device mesh only rank 0 loads
        and the others receive the weights via broadcast.
        """
        if self.config.get("device_mesh") is None:
            # Single GPU mode
            return True
        if dist.is_initialized() and dist.get_rank() == 0:
            # Multi-GPU mode, only rank 0 loads
            logger.info(f"Loading weights from {self.model_path}")
            return True
        return False

    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        """Read one .safetensors file into a {name: tensor} dict on self.device.

        Tensors whose name contains a *sensitive_layer* marker keep the
        sensitive (high-precision) dtype; everything else is cast to the
        standard run dtype.
        """
        standard_dtype = GET_DTYPE()
        sensitive_dtype = GET_SENSITIVE_DTYPE()
        tensors = {}
        with safe_open(file_path, framework="pt", device=str(self.device)) as handle:
            for name in handle.keys():
                if unified_dtype or all(marker not in name for marker in sensitive_layer):
                    tensors[name] = handle.get_tensor(name).to(standard_dtype)
                else:
                    tensors[name] = handle.get_tensor(name).to(sensitive_dtype)
        return tensors

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        """Load and merge every *.safetensors shard found under self.model_path."""
        merged = {}
        for shard_path in glob.glob(os.path.join(self.model_path, "*.safetensors")):
            merged.update(self._load_safetensor_to_dict(shard_path, unified_dtype, sensitive_layer))
        return merged

    def _load_weights_distribute(self, weight_dict, is_weight_loader):
        """Replicate *weight_dict* from the loader rank (rank 0) to all ranks.

        The loader broadcasts tensor metadata first so every rank can
        pre-allocate, then broadcasts each tensor in a deterministic (sorted)
        order. Returns the per-rank weight dict on "cpu" (offload mode) or
        "cuda".

        Args:
            weight_dict: full state dict on the loader rank; None elsewhere.
            is_weight_loader: True only on the rank that read the checkpoint.
        """
        global_src_rank = 0
        target_device = "cpu" if self.cpu_offload else "cuda"

        # Step 1: share {name: (shape, dtype)} so non-loader ranks can allocate.
        if is_weight_loader:
            meta_dict = {}
            for key, tensor in weight_dict.items():
                meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}

            obj_list = [meta_dict]
            dist.broadcast_object_list(obj_list, src=global_src_rank)
            synced_meta_dict = obj_list[0]
        else:
            obj_list = [None]
            dist.broadcast_object_list(obj_list, src=global_src_rank)
            synced_meta_dict = obj_list[0]

        # Step 2: pre-allocate destination tensors on every rank.
        distributed_weight_dict = {}
        for key, meta in synced_meta_dict.items():
            distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)

        if target_device == "cuda":
            dist.barrier(device_ids=[torch.cuda.current_device()])

        # Step 3: broadcast each tensor. Sorted keys guarantee all ranks issue
        # the collectives in the same order.
        for key in sorted(synced_meta_dict.keys()):
            if is_weight_loader:
                distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)

            if target_device == "cpu":
                # NCCL broadcasts require GPU tensors, so stage through CUDA
                # even when the final destination is host memory.
                if is_weight_loader:
                    gpu_tensor = distributed_weight_dict[key].cuda()
                    dist.broadcast(gpu_tensor, src=global_src_rank)
                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
                    del gpu_tensor
                    torch.cuda.empty_cache()
                else:
                    gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda")
                    dist.broadcast(gpu_tensor, src=global_src_rank)
                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
                    del gpu_tensor
                    torch.cuda.empty_cache()

                # NOTE(review): copying a tensor onto itself looks intended to
                # flush the pending pinned-memory transfer — confirm.
                if distributed_weight_dict[key].is_pinned():
                    distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True)
            else:
                dist.broadcast(distributed_weight_dict[key], src=global_src_rank)

        # Step 4: make sure every async copy has completed before returning.
        if target_device == "cuda":
            torch.cuda.synchronize()
        else:
            for tensor in distributed_weight_dict.values():
                if tensor.is_pinned():
                    tensor.copy_(tensor, non_blocking=False)

        logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
        return distributed_weight_dict

    def _init_infer(self):
        """Instantiate the infer engines selected by _init_infer_class."""
        cfg = self.config
        self.transformer_infer = self.transformer_infer_class(cfg)
        self.pre_infer = self.pre_infer_class(cfg)
        self.post_infer = self.post_infer_class(cfg)

    def to_cpu(self):
        """Move all weight containers to host memory."""
        for weights in (self.pre_weight, self.transformer_weights, self.post_weight):
            weights.to_cpu()

    def to_cuda(self):
        """Move all weight containers to GPU memory."""
        for weights in (self.pre_weight, self.transformer_weights, self.post_weight):
            weights.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
        """Run one denoising step and store the result on the scheduler.

        Runs the pre → transformer → post pipeline on the scheduler's current
        latents/timestep, optionally a second (negative-prompt) pass for true
        CFG, and writes the final prediction to ``self.scheduler.noise_pred``.

        Args:
            inputs: dict with "text_encoder_output", "img_shapes", and (for
                the "i2i" task) "image_encoder_output".
        """
        # Stage weights onto the GPU according to the offload granularity.
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

        t = self.scheduler.timesteps[self.scheduler.step_index]
        latents = self.scheduler.latents
        if self.config["task"] == "i2i":
            # i2i: condition by concatenating reference-image latents along dim 1.
            image_latents = inputs["image_encoder_output"]["image_latents"]
            latents_input = torch.cat([latents, image_latents], dim=1)
        else:
            latents_input = latents

        timestep = t.expand(latents.shape[0]).to(latents.dtype)
        img_shapes = inputs["img_shapes"]

        prompt_embeds = inputs["text_encoder_output"]["prompt_embeds"]
        prompt_embeds_mask = inputs["text_encoder_output"]["prompt_embeds_mask"]

        # Per-sample text lengths derived from the attention mask.
        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None

        hidden_states, encoder_hidden_states, _, pre_infer_out = self.pre_infer.infer(
            weights=self.pre_weight,
            hidden_states=latents_input,
            timestep=timestep / 1000,
            guidance=self.scheduler.guidance,
            encoder_hidden_states_mask=prompt_embeds_mask,
            encoder_hidden_states=prompt_embeds,
            img_shapes=img_shapes,
            txt_seq_lens=txt_seq_lens,
            attention_kwargs=self.attention_kwargs,
        )

        encoder_hidden_states, hidden_states = self.transformer_infer.infer(
            block_weights=self.transformer_weights,
            hidden_states=hidden_states.unsqueeze(0),
            encoder_hidden_states=encoder_hidden_states.unsqueeze(0),
            pre_infer_out=pre_infer_out,
        )

        noise_pred = self.post_infer.infer(self.post_weight, hidden_states, pre_infer_out[0])

        # True CFG: a full second pass conditioned on the negative prompt.
        if self.config["do_true_cfg"]:
            neg_prompt_embeds = inputs["text_encoder_output"]["negative_prompt_embeds"]
            neg_prompt_embeds_mask = inputs["text_encoder_output"]["negative_prompt_embeds_mask"]

            negative_txt_seq_lens = neg_prompt_embeds_mask.sum(dim=1).tolist() if neg_prompt_embeds_mask is not None else None

            neg_hidden_states, neg_encoder_hidden_states, _, neg_pre_infer_out = self.pre_infer.infer(
                weights=self.pre_weight,
                hidden_states=latents_input,
                timestep=timestep / 1000,
                guidance=self.scheduler.guidance,
                encoder_hidden_states_mask=neg_prompt_embeds_mask,
                encoder_hidden_states=neg_prompt_embeds,
                img_shapes=img_shapes,
                txt_seq_lens=negative_txt_seq_lens,
                attention_kwargs=self.attention_kwargs,
            )

            neg_encoder_hidden_states, neg_hidden_states = self.transformer_infer.infer(
                block_weights=self.transformer_weights,
                hidden_states=neg_hidden_states.unsqueeze(0),
                encoder_hidden_states=neg_encoder_hidden_states.unsqueeze(0),
                pre_infer_out=neg_pre_infer_out,
            )

            neg_noise_pred = self.post_infer.infer(self.post_weight, neg_hidden_states, neg_pre_infer_out[0])

        # i2i: drop the prediction over the concatenated reference latents,
        # keeping only the positions that correspond to the noisy latents.
        if self.config["task"] == "i2i":
            noise_pred = noise_pred[:, : latents.size(1)]

        if self.config["do_true_cfg"]:
            neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
            comb_pred = neg_noise_pred + self.config["true_cfg_scale"] * (noise_pred - neg_noise_pred)

            # Rescale the combined prediction to the conditional prediction's norm.
            cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
            noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
            noise_pred = comb_pred * (cond_norm / noise_norm)

        # NOTE(review): this slice repeats the i2i slice above (idempotent) —
        # presumably a defensive catch-all; confirm before removing.
        noise_pred = noise_pred[:, : latents.size(1)]
        self.scheduler.noise_pred = noise_pred