model.py 9.71 KB
Newer Older
1
2
3
4
import json
import os

import torch
Watebear's avatar
Watebear committed
5
from safetensors import safe_open
6

Watebear's avatar
Watebear committed
7
8
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
9

10
from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
11
12
13
from .infer.post_infer import QwenImagePostInfer
from .infer.pre_infer import QwenImagePreInfer
from .infer.transformer_infer import QwenImageTransformerInfer
Watebear's avatar
Watebear committed
14
15
16
from .weights.post_weights import QwenImagePostWeights
from .weights.pre_weights import QwenImagePreWeights
from .weights.transformer_weights import QwenImageTransformerWeights
17
18
19


class QwenImageTransformerModel:
Watebear's avatar
Watebear committed
20
21
22
23
    pre_weight_class = QwenImagePreWeights
    transformer_weight_class = QwenImageTransformerWeights
    post_weight_class = QwenImagePostWeights

24
25
    def __init__(self, config):
        self.config = config
Watebear's avatar
Watebear committed
26
        self.model_path = os.path.join(config.model_path, "transformer")
27
        self.cpu_offload = config.get("cpu_offload", False)
Watebear's avatar
Watebear committed
28
29
        self.offload_granularity = self.config.get("offload_granularity", "block")
        self.device = torch.device("cpu") if self.cpu_offload else torch.device("cuda")
30
31
32
33
34
35

        with open(os.path.join(config.model_path, "transformer", "config.json"), "r") as f:
            transformer_config = json.load(f)
            self.in_channels = transformer_config["in_channels"]
        self.attention_kwargs = {}

Watebear's avatar
Watebear committed
36
37
38
        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)

39
        self._init_infer_class()
Watebear's avatar
Watebear committed
40
        self._init_weights()
41
42
43
44
        self._init_infer()

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
Watebear's avatar
Watebear committed
45
46
47
        self.pre_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
48
49
50

    def _init_infer_class(self):
        if self.config["feature_caching"] == "NoCaching":
51
            self.transformer_infer_class = QwenImageTransformerInfer if not self.cpu_offload else QwenImageOffloadTransformerInfer
52
53
54
55
56
        else:
            assert NotImplementedError
        self.pre_infer_class = QwenImagePreInfer
        self.post_infer_class = QwenImagePostInfer

Watebear's avatar
Watebear committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
    def _init_weights(self, weight_dict=None):
        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
        # Some layers run with float32 to achieve high accuracy
        sensitive_layer = {}

        if weight_dict is None:
            is_weight_loader = self._should_load_weights()
            if is_weight_loader:
                if not self.dit_quantized or self.weight_auto_quant:
                    # Load original weights
                    weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                else:
                    # Load quantized weights
                    assert NotImplementedError

            if self.config.get("device_mesh") is not None:
                weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)

            self.original_weight_dict = weight_dict
        else:
            self.original_weight_dict = weight_dict

        # Initialize weight containers
        self.pre_weight = self.pre_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        # Load weights into containers
        self.pre_weight.load(self.original_weight_dict)
        self.transformer_weights.load(self.original_weight_dict)
        self.post_weight.load(self.original_weight_dict)

    def _should_load_weights(self):
        """Determine if current rank should load weights from disk."""
        if self.config.get("device_mesh") is None:
            # Single GPU mode
            return True
        elif dist.is_initialized():
            # Multi-GPU mode, only rank 0 loads
            if dist.get_rank() == 0:
                logger.info(f"Loading weights from {self.model_path}")
                return True
        return False

    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        """Read every tensor of one safetensors file into a dict.

        Each tensor is cast, pinned, and then moved to ``self.device``.

        Args:
            file_path: Path of the ``.safetensors`` file to read.
            unified_dtype: When true, every tensor is cast to ``GET_DTYPE()``.
            sensitive_layer: Iterable of name fragments. When
                ``unified_dtype`` is false, a key containing any fragment is
                cast to ``GET_SENSITIVE_DTYPE()`` instead.

        Returns:
            dict mapping each key in the file to its converted tensor.
        """
        tensors = {}
        with safe_open(file_path, framework="pt") as reader:
            for name in reader.keys():
                is_sensitive = not unified_dtype and any(tag in name for tag in sensitive_layer)
                target_dtype = GET_SENSITIVE_DTYPE() if is_sensitive else GET_DTYPE()
                tensors[name] = reader.get_tensor(name).to(target_dtype).pin_memory().to(self.device)
        return tensors

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)
        return weight_dict

    def _load_weights_distribute(self, weight_dict, is_weight_loader):
        """Broadcast the weights held by rank 0 to every rank.

        The loading rank first broadcasts each tensor's shape and dtype; every
        rank then allocates empty tensors of that layout and receives the data
        through one ``dist.broadcast`` per key.

        Args:
            weight_dict: Mapping of weight name to tensor. Only read when
                ``is_weight_loader`` is true.
            is_weight_loader: Whether this rank holds the loaded weights.

        Returns:
            dict mapping each weight name to a tensor allocated on "cpu" when
            ``self.cpu_offload`` is set, otherwise on "cuda".
        """
        global_src_rank = 0
        target_device = "cpu" if self.cpu_offload else "cuda"

        # Step 1: share the tensor layout (shape/dtype per key) from the source rank.
        if is_weight_loader:
            meta_dict = {}
            for key, tensor in weight_dict.items():
                meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}

            obj_list = [meta_dict]
            dist.broadcast_object_list(obj_list, src=global_src_rank)
            synced_meta_dict = obj_list[0]
        else:
            # Placeholder slot that broadcast_object_list fills in.
            obj_list = [None]
            dist.broadcast_object_list(obj_list, src=global_src_rank)
            synced_meta_dict = obj_list[0]

        # Step 2: allocate uninitialized destination tensors on every rank.
        distributed_weight_dict = {}
        for key, meta in synced_meta_dict.items():
            distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)

        if target_device == "cuda":
            dist.barrier(device_ids=[torch.cuda.current_device()])

        # Step 3: broadcast one tensor per key. Keys are visited in sorted order,
        # presumably so that all ranks issue the collectives in the same sequence.
        for key in sorted(synced_meta_dict.keys()):
            if is_weight_loader:
                distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)

            if target_device == "cpu":
                # CPU target: the broadcast is staged through a temporary CUDA
                # tensor and the result is copied back into the CPU tensor.
                # NOTE(review): presumably because the process-group backend
                # only broadcasts CUDA tensors — confirm.
                if is_weight_loader:
                    gpu_tensor = distributed_weight_dict[key].cuda()
                    dist.broadcast(gpu_tensor, src=global_src_rank)
                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
                    del gpu_tensor
                    torch.cuda.empty_cache()
                else:
                    gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda")
                    dist.broadcast(gpu_tensor, src=global_src_rank)
                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
                    del gpu_tensor
                    torch.cuda.empty_cache()

                # NOTE(review): source and destination are the same tensor, so
                # this copy looks like a no-op — confirm intent.
                if distributed_weight_dict[key].is_pinned():
                    distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True)
            else:
                dist.broadcast(distributed_weight_dict[key], src=global_src_rank)

        if target_device == "cuda":
            torch.cuda.synchronize()
        else:
            # NOTE(review): another self-copy (blocking this time); looks like it
            # is meant as a synchronization point for pinned tensors — confirm.
            for tensor in distributed_weight_dict.values():
                if tensor.is_pinned():
                    tensor.copy_(tensor, non_blocking=False)

        logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
        return distributed_weight_dict

172
    def _init_infer(self):
Watebear's avatar
Watebear committed
173
174
175
176
177
178
179
180
181
182
183
184
185
        self.transformer_infer = self.transformer_infer_class(self.config)
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.transformer_weights.to_cpu()
        self.post_weight.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.transformer_weights.to_cuda()
        self.post_weight.to_cuda()
186
187
188

    @torch.no_grad()
    def infer(self, inputs):
Watebear's avatar
Watebear committed
189
190
191
192
193
194
195
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

196
197
        t = self.scheduler.timesteps[self.scheduler.step_index]
        latents = self.scheduler.latents
198
199
200
201
202
203
        if self.config.task == "i2i":
            image_latents = inputs["image_encoder_output"]["image_latents"]
            latents_input = torch.cat([latents, image_latents], dim=1)
        else:
            latents_input = latents

204
205
206
207
208
209
210
        timestep = t.expand(latents.shape[0]).to(latents.dtype)
        img_shapes = self.scheduler.img_shapes

        prompt_embeds = inputs["text_encoder_output"]["prompt_embeds"]
        prompt_embeds_mask = inputs["text_encoder_output"]["prompt_embeds_mask"]

        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
Watebear's avatar
Watebear committed
211
212
213

        hidden_states, encoder_hidden_states, _, pre_infer_out = self.pre_infer.infer(
            weights=self.pre_weight,
214
            hidden_states=latents_input,
215
216
217
218
219
220
221
222
223
224
            timestep=timestep / 1000,
            guidance=self.scheduler.guidance,
            encoder_hidden_states_mask=prompt_embeds_mask,
            encoder_hidden_states=prompt_embeds,
            img_shapes=img_shapes,
            txt_seq_lens=txt_seq_lens,
            attention_kwargs=self.attention_kwargs,
        )

        encoder_hidden_states, hidden_states = self.transformer_infer.infer(
Watebear's avatar
Watebear committed
225
226
227
            block_weights=self.transformer_weights,
            hidden_states=hidden_states.unsqueeze(0),
            encoder_hidden_states=encoder_hidden_states.unsqueeze(0),
228
229
230
            pre_infer_out=pre_infer_out,
        )

Watebear's avatar
Watebear committed
231
232
        noise_pred = self.post_infer.infer(self.post_weight, hidden_states, pre_infer_out[0])

233
234
        if self.config.task == "i2i":
            noise_pred = noise_pred[:, : latents.size(1)]
235
236

        self.scheduler.noise_pred = noise_pred