model.py 11.8 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
import gc
from pathlib import Path

import torch
from loguru import logger
from safetensors import safe_open
from transformers import AutoImageProcessor, Gemma3ForConditionalGeneration, Gemma3Processor

from lightx2v.models.input_encoders.hf.ltx2.gemma.encoders.av_encoder import (
    AV_GEMMA_TEXT_ENCODER_KEY_OPS,
    AVGemmaTextEncoderModel,
    AVGemmaTextEncoderModelConfigurator,
)
from lightx2v.models.input_encoders.hf.ltx2.gemma.tokenizer import LTXVGemmaTokenizer
from lightx2v.utils.envs import GET_DTYPE
from lightx2v.utils.lora_loader import LoRALoader
from lightx2v.utils.ltx2_utils import *
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)


def _find_matching_dir(root_path: str, pattern: str) -> str:
    """Recursively search for files matching a glob pattern and return the parent directory of the first match."""
    matches = list(Path(root_path).rglob(pattern))
    if not matches:
        raise FileNotFoundError(f"No files matching pattern '{pattern}' found under {root_path}")
    return str(matches[0].parent)


class LTX2TextEncoder:
    """
    Simplified text encoder loader that encapsulates all complex building logic.

    Usage:
        model = LTX2TextEncoder(
            checkpoint_path="/path/to/checkpoint.safetensors",
            gemma_root="/path/to/gemma",
            device=torch.device("cuda"),
            dtype=torch.bfloat16
        )

    This class handles:
    - Loading model configuration from checkpoint
    - Creating model structure
    - Loading Gemma model, tokenizer, and processor from gemma_root
    - Loading weights from checkpoint with key mapping
    - Moving to device and setting dtype
    """

    def __init__(
        self,
        checkpoint_path: str,
        gemma_root: str,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        cpu_offload: bool = False,
    ):
        """
        Initialize the simplified text encoder loader.

        Args:
            checkpoint_path: Path to the checkpoint file containing text encoder weights
            gemma_root: Root directory containing Gemma model, tokenizer, and processor
            device: Target device for the model
            dtype: Data type for model parameters
            cpu_offload: If True, keep the encoder on CPU between calls and move it to
                the accelerator only while encoding (see encode_text).
        """
        self.checkpoint_path = checkpoint_path
        self.gemma_root = gemma_root
        self.device = device
        self.dtype = dtype
        self.cpu_offload = cpu_offload
        self.loader = SafetensorsModelStateDictLoader()
        # Build eagerly so the instance is ready to encode immediately after construction.
        self.text_encoder = self.load()

    def _load_gemma_model(self) -> Gemma3ForConditionalGeneration:
        """Load the Gemma backbone from the directory under gemma_root holding model*.safetensors."""
        gemma_path = _find_matching_dir(self.gemma_root, "model*.safetensors")
        return Gemma3ForConditionalGeneration.from_pretrained(gemma_path, local_files_only=True, torch_dtype=torch.bfloat16)

    def _load_tokenizer(self) -> LTXVGemmaTokenizer:
        """Load the tokenizer (max length 1024) from the directory under gemma_root holding tokenizer.model."""
        tokenizer_path = _find_matching_dir(self.gemma_root, "tokenizer.model")
        return LTXVGemmaTokenizer(tokenizer_path, 1024)

    def _load_processor(self, tokenizer: LTXVGemmaTokenizer) -> Gemma3Processor:
        """Build a Gemma3Processor from the image-processor config under gemma_root plus the given tokenizer."""
        processor_path = _find_matching_dir(self.gemma_root, "preprocessor_config.json")
        image_processor = AutoImageProcessor.from_pretrained(processor_path, local_files_only=True)
        return Gemma3Processor(image_processor=image_processor, tokenizer=tokenizer.tokenizer)

    def load(self) -> AVGemmaTextEncoderModel:
        """
        Load and build the text encoder model.

        Returns:
            AVGemmaTextEncoderModel: The fully initialized text encoder model
        """
        # Step 1: Load configuration from checkpoint metadata
        config = self.loader.metadata(self.checkpoint_path)

        # Step 2: Create model structure (meta model)
        model = AVGemmaTextEncoderModelConfigurator.from_config(config)

        # Step 3: Load Gemma model, tokenizer, and processor from gemma_root
        model.model = self._load_gemma_model()
        model.tokenizer = self._load_tokenizer()
        model.processor = self._load_processor(model.tokenizer)

        # Step 4: Load weights from checkpoint with key mapping
        state_dict_obj = self.loader.load(
            self.checkpoint_path,
            sd_ops=AV_GEMMA_TEXT_ENCODER_KEY_OPS,
            device=self.device,
        )

        # Step 5: Apply dtype conversion if needed
        state_dict = state_dict_obj.sd
        if self.dtype is not None:
            state_dict = {key: value.to(dtype=self.dtype) for key, value in state_dict.items()}

        # Step 6: Load state dict into model. strict=False because the Gemma backbone
        # was already populated in step 3; assign=True replaces meta tensors in place.
        model.load_state_dict(state_dict, strict=False, assign=True)

        # Step 7: Move to device and set eval mode
        model = model.to(self.device).eval()

        return model

    def encode_text(self, prompts: list[str]) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """
        Encode a list of prompts with the underlying Gemma text encoder.

        When cpu_offload is enabled, the encoder is moved to the accelerator for the
        duration of this call and back to CPU afterwards.

        Args:
            prompts: List of prompt strings to encode.

        Returns:
            List of tuples, each containing (v_context, a_context) tensors for each prompt.
        """
        if self.cpu_offload:
            self.text_encoder = self.text_encoder.to(AI_DEVICE)
        result = []
        for prompt in prompts:
            # The encoder returns (video context, audio context, extra); extra is unused here.
            v_context, a_context, _ = self.text_encoder(prompt)
            result.append((v_context, a_context))
        if self.cpu_offload:
            self.text_encoder = self.text_encoder.to("cpu")
        return result

    def infer(
        self,
        prompt: str,
        negative_prompt: str = "",
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Infer text encoder outputs for prompt and negative prompt.

        This is a convenience method that encodes both prompt and negative prompt
        in a single batch and returns the video and audio contexts for both.

        Args:
            prompt: Positive prompt string.
            negative_prompt: Negative prompt string (default: empty string).

        Returns:
            Tuple containing:
            - v_context_p: Video context for positive prompt
            - a_context_p: Audio context for positive prompt
            - v_context_n: Video context for negative prompt
            - a_context_n: Audio context for negative prompt
        """
        contexts = self.encode_text(prompts=[prompt, negative_prompt])
        context_p, context_n = contexts
        v_context_p, a_context_p = context_p
        v_context_n, a_context_n = context_n
        return v_context_p, a_context_p, v_context_n, a_context_n

    def apply_lora(self, lora_configs):
        """
        Apply LoRA weights to text encoder's feature_extractor_linear.

        Args:
            lora_configs: List of LoRA configuration dicts, each containing:
                - path: Path to LoRA safetensors file
                - strength: LoRA strength (default: 1.0)

        Returns:
            bool: True if LoRA was successfully applied, False otherwise
        """
        if not hasattr(self, "text_encoder"):
            logger.warning("Text encoder does not have expected structure. Skipping LoRA application.")
            return False

        encoder_model = self.text_encoder

        # Get the feature_extractor_linear module
        if not hasattr(encoder_model, "feature_extractor_linear"):
            logger.warning("Text encoder does not have feature_extractor_linear. Skipping LoRA application.")
            return False

        feature_extractor = encoder_model.feature_extractor_linear
        if not hasattr(feature_extractor, "aggregate_embed"):
            logger.warning("feature_extractor_linear does not have aggregate_embed. Skipping LoRA application.")
            return False

        # Create LoRALoader without model_prefix (text encoder keys don't need it)
        # Map text_embedding_projection. to feature_extractor_linear.
        key_mapping_rules = [
            (r"^text_embedding_projection\.", "feature_extractor_linear."),
        ]
        lora_loader = LoRALoader(key_mapping_rules=key_mapping_rules)

        for lora_config in lora_configs:
            lora_path = lora_config["path"]
            lora_strength = lora_config.get("strength", 1.0)

            # BUGFIX: the weight dict must be rebuilt on every iteration. It was
            # previously created once before the loop and deleted inside it, which
            # raised NameError for any second lora_config. Rebuilding from the live
            # model weight also lets successive LoRAs stack on already-applied ones.
            weight_dict = {"feature_extractor_linear.aggregate_embed.weight": feature_extractor.aggregate_embed.weight.data.clone()}

            # Load only text_embedding_projection keys to save memory
            with safe_open(lora_path, framework="pt") as f:
                # First, get all keys and filter for text_embedding_projection
                all_keys = list(f.keys())
                text_encoder_keys = [key for key in all_keys if key.startswith("text_embedding_projection.")]

                # Only load the filtered keys
                text_encoder_lora_weights = {key: f.get_tensor(key).to(GET_DTYPE()).to(self.device) for key in text_encoder_keys}

            if text_encoder_lora_weights:
                # Apply LoRA to feature extractor
                applied_count = lora_loader.apply_lora(
                    weight_dict=weight_dict,
                    lora_weights=text_encoder_lora_weights,
                    strength=lora_strength,
                )

                if applied_count > 0:
                    # Update the actual model weights
                    feature_extractor.aggregate_embed.weight.data = weight_dict["feature_extractor_linear.aggregate_embed.weight"]
                    logger.info(f"Successfully applied {applied_count} LoRA weights to text encoder from {lora_path} (strength: {lora_strength})")
                else:
                    logger.warning(f"No LoRA weights were applied to text encoder from {lora_path}")
            else:
                logger.debug(f"No text_embedding_projection LoRA keys found in {lora_path}")

            # Release per-config tensors promptly to keep peak memory low.
            del text_encoder_lora_weights, weight_dict
            gc.collect()

        return True


if __name__ == "__main__":
    # Smoke-test driver: build the encoder from local checkpoints and encode one prompt pair.
    DEFAULT_NEGATIVE_PROMPT = (
        "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
        "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
        "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
        "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
        "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
        "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
        "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
        "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
        "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
        "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
        "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
    )

    encoder = LTX2TextEncoder(
        checkpoint_path="/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors",
        gemma_root="/data/nvme0/gushiqiao/models/official_models/LTX-2",
        device="cuda",
        dtype=torch.bfloat16,
    )

    outputs = encoder.infer(
        prompt="A beautiful sunset over the ocean",
        negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    )
    v_context_p, a_context_p, v_context_n, a_context_n = outputs