# lora_adapter.py
import gc
import os

import torch
from loguru import logger
from safetensors import safe_open

from lightx2v.utils.envs import *


class WanLoraWrapper:
    """Register, merge and remove LoRA adapters on a Wan model's weights.

    The wrapped model is used only through these attributes:
    ``original_weight_dict`` (parameter name -> tensor), ``_init_weights(dict)``,
    ``device`` and, optionally, ``config``.

    LoRA deltas are added in place to the tensors of ``original_weight_dict``.
    Before a parameter is modified for the first time, a CPU copy of its value
    is stored in ``override_dict`` so that ``remove_lora`` can restore it.
    """

    def __init__(self, wan_model):
        self.model = wan_model
        # lora_name -> {"path": <file path>}; files are only read in apply_lora.
        self.lora_metadata = {}
        # parameter name -> CPU copy of its value before any LoRA was merged.
        self.override_dict = {}

    def load_lora(self, lora_path, lora_name=None):
        """Register a LoRA file under ``lora_name`` without reading it.

        Args:
            lora_path: Path to a safetensors LoRA file.
            lora_name: Registry key. Defaults to the file's basename up to the
                first ``"."`` (``"dir/foo.v2.safetensors"`` -> ``"foo"``).

        Returns:
            The name the LoRA is registered under. If that name is already
            registered, the existing entry is kept and the name is returned.
        """
        if lora_name is None:
            # NOTE(review): everything after the first "." is dropped, so
            # "foo.v1.safetensors" and "foo.v2.safetensors" get the same name;
            # pass lora_name explicitly to keep such files apart.
            lora_name = os.path.basename(lora_path).split(".")[0]

        if lora_name in self.lora_metadata:
            logger.info(f"LoRA {lora_name} already loaded, skipping...")
            return lora_name

        self.lora_metadata[lora_name] = {"path": lora_path}
        logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}")
        return lora_name

    def _load_lora_file(self, file_path):
        """Read every tensor of a safetensors file into a ``{key: tensor}`` dict.

        Tensors are cast to ``torch.bfloat16`` when ``GET_DTYPE()`` is ``"BF16"``,
        unless the model config explicitly sets ``use_bfloat16``.
        """
        use_bfloat16 = GET_DTYPE() == "BF16"
        config = getattr(self.model, "config", None)
        if config and hasattr(config, "get"):
            # Only an explicit config entry overrides the global dtype; a missing
            # key keeps the GET_DTYPE() decision instead of forcing bfloat16.
            use_bfloat16 = config.get("use_bfloat16", use_bfloat16)
        with safe_open(file_path, framework="pt") as f:
            if use_bfloat16:
                # Cast tensor-by-tensor so the uncast copy is released immediately.
                return {key: f.get_tensor(key).to(torch.bfloat16) for key in f.keys()}
            return {key: f.get_tensor(key) for key in f.keys()}

    def apply_lora(self, lora_name, alpha=1.0):
        """Merge a registered LoRA into the model weights and re-initialise the model.

        Args:
            lora_name: Name returned by ``load_lora``.
            alpha: Scale applied to every delta before it is added.

        Returns:
            True on success. False if the LoRA is not registered or the model
            has no ``original_weight_dict``.
        """
        if lora_name not in self.lora_metadata:
            # Bail out here: the metadata lookup below would otherwise raise KeyError.
            logger.error(f"LoRA {lora_name} not found. Please load it first.")
            return False

        if not hasattr(self.model, "original_weight_dict"):
            logger.error("Model does not have 'original_weight_dict'. Cannot apply LoRA.")
            return False

        lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])
        weight_dict = self.model.original_weight_dict
        self._apply_lora_weights(weight_dict, lora_weights, alpha)
        # Release the LoRA tensors before re-initialising to reduce peak memory.
        del lora_weights
        self.model._init_weights(weight_dict)

        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        return True

    @torch.no_grad()
    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
        """Add scaled LoRA deltas in place to the matching tensors of ``weight_dict``.

        Only keys starting with ``"diffusion_model."`` are used; the prefix is
        stripped and the trailing suffix swapped to get the target name:

        * ``<m>.lora_A.weight`` + ``<m>.lora_B.weight``: ``<m>.weight += alpha * B @ A``
        * ``<m>.lora_down.weight`` + ``<m>.lora_up.weight``: same, with down=A, up=B
        * ``<m>.diff``: ``<m>.weight += alpha * diff``
        * ``<m>.diff_b``: ``<m>.bias += alpha * diff_b``
        * ``<m>.diff_m``: ``<m>.modulation += alpha * diff_m``

        A low-rank pair takes precedence over a diff for the same target.
        Deltas whose shape does not match the target are skipped. The first time
        a parameter is modified, a CPU copy of its value goes to ``self.override_dict``.
        """
        prefix = "diffusion_model."
        pair_suffixes = (
            ("lora_A.weight", "lora_B.weight", "weight"),
            ("lora_down.weight", "lora_up.weight", "weight"),
        )
        diff_suffixes = (
            ("diff", "weight"),
            ("diff_b", "bias"),
            ("diff_m", "modulation"),
        )

        def target_name(key, suffix, target_suffix):
            # Swap only the trailing suffix so layer names that merely contain
            # the suffix text (e.g. "diff_layer") are left intact.
            return key[len(prefix) : -len(suffix)] + target_suffix

        lora_pairs = {}  # target param name -> (A key, B key)
        lora_diffs = {}  # target param name -> diff key
        for key in lora_weights:
            if not key.startswith(prefix):
                continue
            for suffix_a, suffix_b, target_suffix in pair_suffixes:
                if key.endswith(suffix_a):
                    pair_key = key[: -len(suffix_a)] + suffix_b
                    if pair_key in lora_weights:
                        lora_pairs[target_name(key, suffix_a, target_suffix)] = (key, pair_key)
            for suffix, target_suffix in diff_suffixes:
                if key.endswith(suffix):
                    lora_diffs[target_name(key, suffix, target_suffix)] = key

        applied_count = 0
        skipped_count = 0
        for name, param in weight_dict.items():
            if name in lora_pairs:
                name_lora_A, name_lora_B = lora_pairs[name]
                lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
                if param.shape != (lora_B.shape[0], lora_A.shape[1]):
                    skipped_count += 1
                    continue
                delta = torch.matmul(lora_B, lora_A)
            elif name in lora_diffs:
                delta = lora_weights[lora_diffs[name]].to(param.device, param.dtype)
                if param.shape != delta.shape:
                    skipped_count += 1
                    continue
            else:
                continue

            if name not in self.override_dict:
                # Snapshot the untouched value once so remove_lora can restore it;
                # copying straight to CPU avoids an extra on-device clone.
                self.override_dict[name] = param.detach().to("cpu", copy=True)
            param += delta * alpha
            applied_count += 1

        logger.info(f"Applied {applied_count} LoRA weight adjustments")
        if skipped_count:
            logger.warning(f"Skipped {skipped_count} LoRA weight adjustments due to shape mismatch")
        if applied_count == 0:
            logger.warning(
                "Warning: No LoRA weights were applied. Expected naming conventions: 'diffusion_model.<layer_name>.lora_A.weight' and 'diffusion_model.<layer_name>.lora_B.weight'. Please verify the LoRA weight file."
            )

    @torch.no_grad()
    def remove_lora(self):
        """Restore every parameter changed by ``apply_lora`` and unregister all LoRAs.

        Each saved CPU copy is moved to ``self.model.device`` and written back
        into ``original_weight_dict``. The model is then re-initialised from that
        dict. Registered LoRAs are forgotten and must be loaded again before reuse.
        """
        logger.info("Removing LoRA ...")

        original_weights = self.model.original_weight_dict
        for name, saved in self.override_dict.items():
            original_weights[name] = saved.to(self.model.device)
        restored_count = len(self.override_dict)
        logger.info(f"LoRA removed, restored {restored_count} weights")

        self.model._init_weights(original_weights)

        self.lora_metadata = {}
        self.override_dict = {}
        # Collect unreachable objects first so the CUDA caching allocator can
        # actually hand back the blocks they held.
        gc.collect()
        torch.cuda.empty_cache()

    def list_loaded_loras(self):
        """Return the names of all registered LoRAs, in registration order."""
        return list(self.lora_metadata.keys())