import gc
import glob
import json
import os

import torch
import torch.distributed as dist
from safetensors import safe_open

from lightx2v.utils.envs import *
from lightx2v.utils.utils import *

from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
from .infer.post_infer import QwenImagePostInfer
from .infer.pre_infer import QwenImagePreInfer
from .infer.transformer_infer import QwenImageTransformerInfer
from .weights.post_weights import QwenImagePostWeights
from .weights.pre_weights import QwenImagePreWeights
from .weights.transformer_weights import QwenImageTransformerWeights


class QwenImageTransformerModel:
    """Qwen-Image DiT wrapper: owns the weights and runs one denoising step.

    Weights are read from ``<config["model_path"]>/transformer`` (or from an
    explicit checkpoint path given in the config), loaded into the pre /
    transformer / post weight containers, and consumed by the matching
    ``*Infer`` objects in :meth:`infer`, which stores the predicted noise in
    ``self.scheduler.noise_pred``.
    """

    pre_weight_class = QwenImagePreWeights
    transformer_weight_class = QwenImageTransformerWeights
    post_weight_class = QwenImagePostWeights

    def __init__(self, config):
        """Read model metadata, load the weights and build the inference stages.

        Args:
            config: Mapping-like runtime config. ``model_path`` is required;
                optional keys read here include ``cpu_offload``,
                ``offload_granularity``, ``run_device`` and ``dit_quantized``.
        """
        self.config = config
        self.model_path = os.path.join(config["model_path"], "transformer")
        self.cpu_offload = config.get("cpu_offload", False)
        self.offload_granularity = self.config.get("offload_granularity", "block")
        # With CPU offload the weights are loaded on the host; infer() moves them to CUDA as needed.
        self.device = torch.device("cpu") if self.cpu_offload else torch.device(self.config.get("run_device", "cuda"))

        with open(os.path.join(self.model_path, "config.json"), "r", encoding="utf-8") as f:
            transformer_config = json.load(f)
        self.in_channels = transformer_config["in_channels"]
        self.attention_kwargs = {}

        self.dit_quantized = self.config.get("dit_quantized", False)

        self._init_infer_class()
        self._init_weights()
        self._init_infer()

    def set_scheduler(self, scheduler):
        """Attach ``scheduler`` to this model and to all three inference stages."""
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)

    def _init_infer_class(self):
        """Select the inference classes for the configured caching / offload mode.

        Raises:
            NotImplementedError: if ``feature_caching`` is anything other than
                ``"NoCaching"``.
        """
        feature_caching = self.config.get("feature_caching", "NoCaching")
        if feature_caching != "NoCaching":
            # Must be a raise: ``assert NotImplementedError`` is always truthy and would fall through.
            raise NotImplementedError(f"feature_caching={feature_caching!r} is not supported for Qwen-Image")
        self.transformer_infer_class = QwenImageOffloadTransformerInfer if self.cpu_offload else QwenImageTransformerInfer
        self.pre_infer_class = QwenImagePreInfer
        self.post_infer_class = QwenImagePostInfer

    def _init_weights(self, weight_dict=None):
        """Obtain the raw ``{name: tensor}`` dict and build the weight containers.

        Args:
            weight_dict: Pre-loaded weights. When ``None`` they are read from
                disk on the ranks selected by :meth:`_should_load_weights` and,
                with ``device_mesh`` + ``load_from_rank0``, broadcast from rank 0.

        When ``lora_configs`` is set the containers are created but not filled;
        the raw dict stays in ``self.original_weight_dict`` (presumably so LoRA
        weights can be merged before :meth:`_apply_weights` is called — confirm).
        """
        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
        # Name substrings of layers that should use GET_SENSITIVE_DTYPE() (intended for
        # higher-precision layers); empty here, so every float tensor uses GET_DTYPE().
        sensitive_layer = {}

        if weight_dict is None:
            is_weight_loader = self._should_load_weights()
            if is_weight_loader:
                if not self.dit_quantized:
                    weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                elif not self.config.get("lazy_load", False):
                    weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
                else:
                    weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)

            if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
                weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)

        self.original_weight_dict = weight_dict

        # Initialize weight containers
        self.pre_weight = self.pre_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)

        if not self._should_init_empty_model():
            self._apply_weights()

    def _apply_weights(self, weight_dict=None):
        """Load the raw weights into the three containers, then free the raw dict.

        Args:
            weight_dict: Optional replacement for ``self.original_weight_dict``.
        """
        if weight_dict is not None:
            self.original_weight_dict = weight_dict

        for container in (self.pre_weight, self.transformer_weights, self.post_weight):
            container.load(self.original_weight_dict)

        # Collect garbage before emptying the CUDA cache so memory freed by the
        # collector is also released from PyTorch's caching allocator.
        del self.original_weight_dict
        gc.collect()
        torch.cuda.empty_cache()

    def _should_load_weights(self):
        """Determine if current rank should load weights from disk."""
        if self.config.get("device_mesh") is None:
            # Single GPU mode
            return True
        if not dist.is_initialized():
            return False
        if not self.config.get("load_from_rank0", False):
            # Multi-GPU mode where every rank reads the checkpoint itself.
            return True
        # Multi-GPU mode, only rank 0 loads
        if dist.get_rank() == 0:
            logger.info(f"Loading weights from {self.model_path}")
            return True
        return False

    def _should_init_empty_model(self):
        """Return True when LoRA configs are present (containers are left unfilled)."""
        return bool(self.config.get("lora_configs"))

    @staticmethod
    def _list_safetensors_files(path):
        """Resolve a checkpoint file or directory to ``(files, directory)``.

        A directory yields its ``*.safetensors`` files in sorted order, so
        duplicate keys across shards resolve deterministically (later files
        win in ``dict.update``). A file path yields itself and its parent.

        Raises:
            FileNotFoundError: if ``path`` is a directory with no ``*.safetensors`` files.
        """
        if os.path.isdir(path):
            files = sorted(glob.glob(os.path.join(path, "*.safetensors")))
            if not files:
                raise FileNotFoundError(f"No *.safetensors files found in {path}")
            return files, path
        return [path], os.path.dirname(path)

    def _cast_quant_tensor(self, tensor, key, unified_dtype, sensitive_layer):
        """Cast a float16/bfloat16/float32 tensor to the run dtype and move it to ``self.device``.

        Tensors of any other dtype (presumably quantized payloads) keep their
        dtype and are only moved.
        """
        if tensor.dtype in (torch.float16, torch.bfloat16, torch.float32):
            if unified_dtype or all(s not in key for s in sensitive_layer):
                tensor = tensor.to(GET_DTYPE())
            else:
                tensor = tensor.to(GET_SENSITIVE_DTYPE())
        return tensor.to(self.device)

    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        """Read one ``.safetensors`` file, cast every tensor and drop ``remove_keys`` matches."""
        remove_keys = getattr(self, "remove_keys", [])

        if self.device.type in ["cuda", "mlu", "npu"] and dist.is_initialized():
            # The global rank is not a valid device index on multi-node runs; prefer
            # LOCAL_RANK (set by torchrun) and fall back to the global rank otherwise.
            local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
            device = torch.device(f"{self.device.type}:{local_rank}")
        else:
            device = self.device

        def _target_dtype(key):
            if unified_dtype or all(s not in key for s in sensitive_layer):
                return GET_DTYPE()
            return GET_SENSITIVE_DTYPE()

        with safe_open(file_path, framework="pt", device=str(device)) as f:
            return {key: f.get_tensor(key).to(_target_dtype(key)) for key in f.keys() if not any(remove_key in key for remove_key in remove_keys)}

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        """Load an unquantized checkpoint (single file or directory of shards) into one dict."""
        safetensors_path = self.config.get("dit_original_ckpt", None) or self.model_path
        safetensors_files, _ = self._list_safetensors_files(safetensors_path)

        weight_dict = {}
        for file_path in safetensors_files:
            logger.info(f"Loading weights from {file_path}")
            weight_dict.update(self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer))
        return weight_dict

    def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
        """Load a quantized checkpoint, plus ``calib.pt`` absmax data for the nvfp4 scheme."""
        remove_keys = getattr(self, "remove_keys", [])
        safetensors_path = self.config.get("dit_quantized_ckpt", None) or self.model_path
        safetensors_files, ckpt_dir = self._list_safetensors_files(safetensors_path)

        weight_dict = {}
        for safetensor_path in safetensors_files:
            with safe_open(safetensor_path, framework="pt") as f:
                logger.info(f"Loading weights from {safetensor_path}")
                for k in f.keys():
                    if any(remove_key in k for remove_key in remove_keys):
                        continue
                    # Read each tensor once: every get_tensor call reads it from the file again.
                    weight_dict[k] = self._cast_quant_tensor(f.get_tensor(k), k, unified_dtype, sensitive_layer)

        if self.config.get("dit_quant_scheme", "Default") == "nvfp4":
            calib_path = os.path.join(ckpt_dir, "calib.pt")
            # NOTE(review): torch.load unpickles its input; only use trusted checkpoints
            # (weights_only=True would suffice if calib.pt holds plain tensors — confirm).
            calib_data = torch.load(calib_path, map_location="cpu")
            logger.info(f"[CALIB] Loaded calibration data from: {calib_path}")
            for k, v in calib_data["absmax"].items():
                weight_dict[k.replace(".weight", ".input_absmax")] = v.to(self.device)

        return weight_dict

    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
        """Lazy-load mode: load only ``non_block.safetensors`` from the quantized checkpoint dir.

        Presumably these are the pre/post (non-block) weights, with block
        weights loaded lazily elsewhere — confirm. TODO: needs a rewrite.
        """
        # ``dit_quantized_ckpt`` is not set as an attribute by this class, so fall back
        # to the same config key / default directory that _load_quant_ckpt uses.
        lazy_load_model_path = getattr(self, "dit_quantized_ckpt", None) or self.config.get("dit_quantized_ckpt", None) or self.model_path
        logger.info(f"Loading splited quant model from {lazy_load_model_path}")

        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
            return {k: self._cast_quant_tensor(f.get_tensor(k), k, unified_dtype, sensitive_layer) for k in f.keys()}

    def _load_weights_from_rank0(self, weight_dict, is_weight_loader):
        """Broadcast the weights loaded on global rank 0 to every rank.

        Args:
            weight_dict: ``{name: tensor}`` on the loader rank; ignored elsewhere.
            is_weight_loader: True on the rank that read the checkpoint.

        Returns:
            A new ``{name: tensor}`` dict, on ``"cpu"`` when CPU offload is
            enabled and on ``"cuda"`` otherwise.
        """
        logger.info("Loading distributed weights")
        global_src_rank = 0
        target_device = "cpu" if self.cpu_offload else "cuda"

        # Step 1: share names / shapes / dtypes so every rank can allocate receive buffers.
        if is_weight_loader:
            meta_dict = {key: {"shape": tensor.shape, "dtype": tensor.dtype} for key, tensor in weight_dict.items()}
        else:
            meta_dict = None
        obj_list = [meta_dict]
        dist.broadcast_object_list(obj_list, src=global_src_rank)
        synced_meta_dict = obj_list[0]

        distributed_weight_dict = {
            key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)
            for key, meta in synced_meta_dict.items()
        }

        if target_device == "cuda":
            dist.barrier(device_ids=[torch.cuda.current_device()])

        # Step 2: broadcast the data. Sorted keys make every rank issue the collectives in the same order.
        for key in sorted(synced_meta_dict):
            buffer = distributed_weight_dict[key]
            if is_weight_loader:
                buffer.copy_(weight_dict[key], non_blocking=True)

            if target_device == "cuda":
                dist.broadcast(buffer, src=global_src_rank)
                continue

            # CPU target: stage through a temporary CUDA tensor (presumably because the
            # process-group backend, e.g. NCCL, only broadcasts CUDA tensors).
            if is_weight_loader:
                gpu_tensor = buffer.cuda()
            else:
                gpu_tensor = torch.empty_like(buffer, device="cuda")
            dist.broadcast(gpu_tensor, src=global_src_rank)
            if not is_weight_loader:
                # Blocking device-to-host copy; the loader's buffer already holds this data.
                buffer.copy_(gpu_tensor)
            del gpu_tensor

        if target_device == "cuda":
            torch.cuda.synchronize()
        else:
            # Release the staging buffers once, instead of after every tensor.
            torch.cuda.empty_cache()

        logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
        return distributed_weight_dict

    def _init_infer(self):
        """Instantiate the pre / transformer / post inference stages."""
        self.transformer_infer = self.transformer_infer_class(self.config)
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        # Offload-capable infer objects need their CUDA buffers sized from the weight containers.
        if hasattr(self.transformer_infer, "offload_manager"):
            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)

    def to_cpu(self):
        """Move all three weight containers to CPU."""
        self.pre_weight.to_cpu()
        self.transformer_weights.to_cpu()
        self.post_weight.to_cpu()

    def to_cuda(self):
        """Move all three weight containers to CUDA."""
        self.pre_weight.to_cuda()
        self.transformer_weights.to_cuda()
        self.post_weight.to_cuda()

    def _infer_cond(self, latents_input, timestep, img_shapes, prompt_embeds, prompt_embeds_mask):
        """Run pre -> transformer -> post for one text condition and return the raw prediction."""
        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None

        hidden_states, encoder_hidden_states, _, pre_infer_out = self.pre_infer.infer(
            weights=self.pre_weight,
            hidden_states=latents_input,
            # Timesteps are divided by 1000 before being passed to pre_infer.
            timestep=timestep / 1000,
            guidance=self.scheduler.guidance,
            encoder_hidden_states_mask=prompt_embeds_mask,
            encoder_hidden_states=prompt_embeds,
            img_shapes=img_shapes,
            txt_seq_lens=txt_seq_lens,
            attention_kwargs=self.attention_kwargs,
        )

        encoder_hidden_states, hidden_states = self.transformer_infer.infer(
            block_weights=self.transformer_weights,
            hidden_states=hidden_states.unsqueeze(0),
            encoder_hidden_states=encoder_hidden_states.unsqueeze(0),
            pre_infer_out=pre_infer_out,
        )

        return self.post_infer.infer(self.post_weight, hidden_states, pre_infer_out[0])

    @torch.no_grad()
    def infer(self, inputs):
        """Run one denoising step and store the result in ``self.scheduler.noise_pred``.

        The timestep and noisy latents come from the scheduler. For
        ``task == "i2i"`` the reference-image latents are concatenated to the
        noisy latents along dim 1; the prediction is always truncated back to
        ``latents.size(1)`` along dim 1. With ``do_true_cfg`` a negative-prompt
        pass is combined as ``neg + true_cfg_scale * (pos - neg)`` and rescaled
        so its norm over the last dim matches the conditional prediction's.

        Args:
            inputs: dict with ``img_shapes``, ``text_encoder_output`` (holding
                ``prompt_embeds`` / ``prompt_embeds_mask`` and, for true CFG,
                ``negative_prompt_embeds`` / ``negative_prompt_embeds_mask``)
                and, for i2i, ``image_encoder_output`` (a list of dicts with
                ``image_latents``).
        """
        if self.cpu_offload:
            # NOTE(review): weights moved to CUDA here are never moved back to CPU in this
            # class; presumably the runner handles that — confirm.
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

        t = self.scheduler.timesteps[self.scheduler.step_index]
        latents = self.scheduler.latents
        if self.config["task"] == "i2i":
            image_latents = torch.cat([item["image_latents"] for item in inputs["image_encoder_output"]], dim=1)
            latents_input = torch.cat([latents, image_latents], dim=1)
        else:
            latents_input = latents

        timestep = t.expand(latents.shape[0]).to(latents.dtype)
        img_shapes = inputs["img_shapes"]
        text_encoder_output = inputs["text_encoder_output"]
        # Predictions for appended image tokens are discarded; only the noisy-latent part is kept.
        num_latent_tokens = latents.size(1)

        noise_pred = self._infer_cond(
            latents_input,
            timestep,
            img_shapes,
            text_encoder_output["prompt_embeds"],
            text_encoder_output["prompt_embeds_mask"],
        )[:, :num_latent_tokens]

        if self.config.get("do_true_cfg", False):
            neg_noise_pred = self._infer_cond(
                latents_input,
                timestep,
                img_shapes,
                text_encoder_output["negative_prompt_embeds"],
                text_encoder_output["negative_prompt_embeds_mask"],
            )[:, :num_latent_tokens]
            comb_pred = neg_noise_pred + self.config["true_cfg_scale"] * (noise_pred - neg_noise_pred)

            cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
            noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
            noise_pred = comb_pred * (cond_norm / noise_norm)

        self.scheduler.noise_pred = noise_pred