lora_adapter.py
import gc
import os

import torch
from loguru import logger
from safetensors import safe_open

from lightx2v.utils.envs import *


class WanLoraWrapper:
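    """Utility for loading, applying, and removing LoRA weights on a Wan model.

    LoRA checkpoints are registered by path, merged in place into the model's
    `original_weight_dict`, and every overwritten parameter is cached on CPU in
    `override_dict` so `remove_lora()` can restore the original weights.
    """
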
    def __init__(self, wan_model):
        self.model = wan_model
        self.lora_metadata = {}
        self.override_dict = {}  # CPU copies of the original params, used by remove_lora()

    def load_lora(self, lora_path, lora_name=None):
        if lora_name is None:
            lora_name = os.path.basename(lora_path).split(".")[0]

        if lora_name in self.lora_metadata:
            logger.info(f"LoRA {lora_name} already loaded, skipping...")
            return lora_name

        self.lora_metadata[lora_name] = {"path": lora_path}
        logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}")

        return lora_name

    def _load_lora_file(self, file_path):
        with safe_open(file_path, framework="pt") as f:
            tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
        return tensor_dict

    def apply_lora(self, lora_name, alpha=1.0):
        if lora_name not in self.lora_metadata:
            logger.warning(f"LoRA {lora_name} not found. Please load it first.")
            return False

        if not hasattr(self.model, "original_weight_dict"):
            logger.error("Model does not have 'original_weight_dict'. Cannot apply LoRA.")
            return False

        lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])
        weight_dict = self.model.original_weight_dict
        self._apply_lora_weights(weight_dict, lora_weights, alpha)
        self.model._init_weights(weight_dict)

        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        del lora_weights
        return True

    @torch.no_grad()
    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
        lora_pairs = {}
        lora_diffs = {}
        prefix = "diffusion_model."

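        # Recognized key layouts under the 'diffusion_model.' prefix:
        #   <layer>.lora_A.weight / <layer>.lora_B.weight       (PEFT-style low-rank pair)
        #   <layer>.lora_down.weight / <layer>.lora_up.weight   (kohya-style low-rank pair)
        #   <layer>.diff, .diff_b, .diff_m                       (full-shape deltas for weight / bias / modulation)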
        def try_lora_pair(key, suffix_a, suffix_b, target_suffix):
            if key.endswith(suffix_a):
                base_name = key[len(prefix) :].replace(suffix_a, target_suffix)
                pair_key = key.replace(suffix_a, suffix_b)
                if pair_key in lora_weights:
                    lora_pairs[base_name] = (key, pair_key)

        def try_lora_diff(key, suffix, target_suffix):
            if key.endswith(suffix):
                base_name = key[len(prefix) :].replace(suffix, target_suffix)
                lora_diffs[base_name] = key

        for key in lora_weights.keys():
            if not key.startswith(prefix):
                continue

            try_lora_pair(key, "lora_A.weight", "lora_B.weight", "weight")
            try_lora_pair(key, "lora_down.weight", "lora_up.weight", "weight")
            try_lora_diff(key, "diff", "weight")
            try_lora_diff(key, "diff_b", "bias")
            try_lora_diff(key, "diff_m", "modulation")

        applied_count = 0
        for name, param in weight_dict.items():
            if name in lora_pairs:
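                # Snapshot the untouched parameter once (kept on CPU) so remove_lora()
                # can restore it later.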
                if name not in self.override_dict:
                    self.override_dict[name] = param.clone().cpu()
                name_lora_A, name_lora_B = lora_pairs[name]
                lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
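                # Standard LoRA merge: the low-rank product B @ A has shape
                # [out_features, in_features] and is added to the base weight, scaled by alpha.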
                if param.shape == (lora_B.shape[0], lora_A.shape[1]):
                    param += torch.matmul(lora_B, lora_A) * alpha
                    applied_count += 1
            elif name in lora_diffs:
                if name not in self.override_dict:
                    self.override_dict[name] = param.clone().cpu()

                name_diff = lora_diffs[name]
                lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
                if param.shape == lora_diff.shape:
                    param += lora_diff * alpha
                    applied_count += 1

        logger.info(f"Applied {applied_count} LoRA weight adjustments")
        if applied_count == 0:
            logger.warning(
                "No LoRA weights were applied. Expected naming conventions: 'diffusion_model.<layer_name>.lora_A.weight' and 'diffusion_model.<layer_name>.lora_B.weight'. Please verify the LoRA weight file."
            )

    @torch.no_grad()
    def remove_lora(self):
        logger.info("Removing LoRA ...")

        restored_count = 0
        for k, v in self.override_dict.items():
            self.model.original_weight_dict[k] = v.to(self.model.device)
            restored_count += 1

        logger.info(f"LoRA removed, restored {restored_count} weights")

        self.model._init_weights(self.model.original_weight_dict)

        torch.cuda.empty_cache()
        gc.collect()

        self.lora_metadata = {}
        self.override_dict = {}

    def list_loaded_loras(self):
        return list(self.lora_metadata.keys())
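

if __name__ == "__main__":
    # Minimal usage sketch / smoke test. It relies only on the model interface the
    # wrapper actually touches (`original_weight_dict`, `_init_weights`, `device`);
    # the dummy model and the synthetic LoRA checkpoint below are illustrative
    # stand-ins, not lightx2v APIs.
    import tempfile

    from safetensors.torch import save_file

    class _DummyModel:
        def __init__(self):
            self.device = "cpu"
            self.original_weight_dict = {"blocks.0.attn.weight": torch.zeros(8, 4)}

        def _init_weights(self, weight_dict):
            # A real model rebuilds its modules from weight_dict here.
            pass

    # Synthesize a LoRA checkpoint whose keys follow the expected
    # 'diffusion_model.<layer>.lora_A.weight' / 'lora_B.weight' convention.
    lora_state = {
        "diffusion_model.blocks.0.attn.lora_A.weight": torch.randn(2, 4),
        "diffusion_model.blocks.0.attn.lora_B.weight": torch.randn(8, 2),
    }
    with tempfile.TemporaryDirectory() as tmp_dir:
        lora_path = os.path.join(tmp_dir, "demo_lora.safetensors")
        save_file(lora_state, lora_path)

        wrapper = WanLoraWrapper(_DummyModel())
        name = wrapper.load_lora(lora_path)
        wrapper.apply_lora(name, alpha=0.8)
        logger.info(f"Loaded LoRAs: {wrapper.list_loaded_loras()}")
        wrapper.remove_lora()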