# lora_adapter.py
import os
import torch
from safetensors import safe_open
from loguru import logger
import gc
from lightx2v.utils.envs import *


class WanLoraWrapper:
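    """Merge LoRA weights into a Wan model's weight dict in place.

    LoRA files are registered by path, merged on demand, and the original
    parameters they touch are cached on CPU (``override_dict``) so an applied
    LoRA can later be removed by restoring them.
    """
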
    def __init__(self, wan_model):
        self.model = wan_model
        self.lora_metadata = {}
        self.override_dict = {}  # On CPU

    def load_lora(self, lora_path, lora_name=None):
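        """Register a LoRA file path under a name (the file stem by default); the weights are read lazily in apply_lora()."""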
        if lora_name is None:
            lora_name = os.path.basename(lora_path).split(".")[0]

        if lora_name in self.lora_metadata:
            logger.info(f"LoRA {lora_name} already loaded, skipping...")
            return lora_name

        self.lora_metadata[lora_name] = {"path": lora_path}
        logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}")

        return lora_name

    def _load_lora_file(self, file_path):
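        """Load every tensor from a safetensors file, casting to bfloat16 when configured."""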
        use_bfloat16 = GET_DTYPE() == "BF16"
        if self.model.config and hasattr(self.model.config, "get"):
            use_bfloat16 = self.model.config.get("use_bfloat16", True)
        with safe_open(file_path, framework="pt") as f:
            if use_bfloat16:
                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16) for key in f.keys()}
            else:
                tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
        return tensor_dict

    def apply_lora(self, lora_name, alpha=1.0):
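        """Merge the named LoRA into the model weights, scaled by ``alpha``, and re-initialize the model."""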
        if lora_name not in self.lora_metadata:
            logger.error(f"LoRA {lora_name} not found. Please load it first.")
            return False

        if hasattr(self.model, "current_lora") and self.model.current_lora:
            self.remove_lora()

        if not hasattr(self.model, "original_weight_dict"):
            logger.error("Model does not have 'original_weight_dict'. Cannot apply LoRA.")
            return False

        lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])
        weight_dict = self.model.original_weight_dict
        self._apply_lora_weights(weight_dict, lora_weights, alpha)
        self.model._init_weights(weight_dict)

        self.model.current_lora = lora_name
        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        return True

    @torch.no_grad()
    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
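        """Add ``alpha * (lora_B @ lora_A)`` to every base weight that has a matching LoRA pair."""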
        lora_pairs = {}
        prefix = "diffusion_model."

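        # Pair each "lora_A.weight" tensor with its "lora_B.weight" counterpart,
        # keyed by the base weight name inside the model's weight dict.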
        for key in lora_weights.keys():
            if key.endswith("lora_A.weight") and key.startswith(prefix):
                base_name = key[len(prefix) :].replace("lora_A.weight", "weight")
                b_key = key.replace("lora_A.weight", "lora_B.weight")
                if b_key in lora_weights:
                    lora_pairs[base_name] = (key, b_key)

        applied_count = 0
        for name, param in weight_dict.items():
            if name in lora_pairs:
                if name not in self.override_dict:
                    self.override_dict[name] = param.clone().cpu()

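                # LoRA update: W <- W + alpha * (lora_B @ lora_A); lora_A is (rank, in_features)
                # and lora_B is (out_features, rank), so the product matches W's shape.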
                name_lora_A, name_lora_B = lora_pairs[name]
                lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
                param += torch.matmul(lora_B, lora_A) * alpha
                applied_count += 1

        logger.info(f"Applied {applied_count} LoRA weight adjustments")
        if applied_count == 0:
            logger.warning(
                "No LoRA weights were applied. Expected naming convention: 'diffusion_model.<layer_name>.lora_A.weight' and 'diffusion_model.<layer_name>.lora_B.weight'. Please verify the LoRA weight file."
            )

    @torch.no_grad()
    def remove_lora(self):
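        """Restore the CPU-cached original weights, re-initialize the model, and clear LoRA state."""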
        if not self.model.current_lora:
            logger.info("No LoRA currently applied")
            return
        logger.info(f"Removing LoRA {self.model.current_lora}...")

        restored_count = 0
        for k, v in self.override_dict.items():
            self.model.original_weight_dict[k] = v.to(self.model.device)
            restored_count += 1

        logger.info(f"LoRA {self.model.current_lora} removed, restored {restored_count} weights")

        self.model._init_weights(self.model.original_weight_dict)

        torch.cuda.empty_cache()
        gc.collect()

        if self.model.current_lora and self.model.current_lora in self.lora_metadata:
            del self.lora_metadata[self.model.current_lora]
        self.override_dict = {}
        self.model.current_lora = None

    def list_loaded_loras(self):
        return list(self.lora_metadata.keys())

    def get_current_lora(self):
        return self.model.current_lora
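

# Usage sketch (assumes `wan_model` is an already-constructed Wan model exposing
# `config`, `original_weight_dict`, `_init_weights`, and `device`; the
# .safetensors path below is a placeholder):
#
#   wrapper = WanLoraWrapper(wan_model)
#   lora_name = wrapper.load_lora("/path/to/style_lora.safetensors")  # registers the path only
#   wrapper.apply_lora(lora_name, alpha=0.8)  # merges alpha * (lora_B @ lora_A) into the weights
#   wrapper.remove_lora()                     # restores the cached original weights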