# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import re
from collections import defaultdict

import torch
from modelscope.utils.import_utils import is_swift_available
from safetensors.torch import load_file


def merge_lora(pipeline, lora_path, multiplier, from_safetensor=False, device='cpu', dtype=torch.float32):
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    print('----------')
    print('Lora Path:', lora_path)
    if from_safetensor:
        state_dict = load_file(lora_path, device=device)
    elif os.path.exists(os.path.join(lora_path, 'swift')):
        if not is_swift_available():
            raise ValueError('Please install swift by `pip install ms-swift` to use efficient_tuners.')
        from swift import Swift
        pipeline.unet = Swift.from_pretrained(pipeline.unet, os.path.join(lora_path, 'swift'))
        return pipeline
    else:
        if os.path.exists(os.path.join(lora_path, 'pytorch_lora_weights.bin')):
            checkpoint = torch.load(
                os.path.join(lora_path, 'pytorch_lora_weights.bin'), map_location=torch.device(device))
        elif os.path.exists(os.path.join(lora_path, 'pytorch_lora_weights.safetensors')):
            checkpoint = load_file(os.path.join(lora_path, 'pytorch_lora_weights.safetensors'), device=device)
        else:
            raise FileNotFoundError(
                f'No pytorch_lora_weights.bin or pytorch_lora_weights.safetensors found in {lora_path}')
        # Remap diffusers attention-processor LoRA keys to the kohya-style names used below.
        new_dict = dict()
        for key in checkpoint:
            new_key = re.sub(r'\.processor\.', '_', key)
            new_key = re.sub(r'mid_block\.', 'mid_block_', new_key)
            new_key = re.sub('_lora.up.', '.lora_up.', new_key)
            new_key = re.sub('_lora.down.', '.lora_down.', new_key)
            new_key = re.sub(r'\.(\d+)\.', '_\\1_', new_key)
            new_key = re.sub('to_out', 'to_out_0', new_key)
            new_key = 'lora_unet_' + new_key
            new_dict[new_key] = checkpoint[key]
        state_dict = new_dict
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

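    # 'updates' groups the LoRA tensors by target layer, e.g.
    # 'lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q' ->
    # {'lora_up.weight': ..., 'lora_down.weight': ..., 'alpha': ...} ('alpha' may be absent).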
    for layer, elems in updates.items():

        if "text" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

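        # Walk down the module tree to the target sub-module; name parts that are not
        # attributes on their own are greedily re-joined with '_' (e.g. 'down' + 'blocks').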
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                else:
                    break
            except Exception:
                if len(layer_infos) == 0:
                    print('Error loading layer:', layer)
                    break
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        weight_up = elems['lora_up.weight'].to(dtype)
        weight_down = elems['lora_down.weight'].to(dtype)
        if 'alpha' in elems.keys():
            alpha = elems['alpha'].item() / weight_up.shape[1]
        else:
            alpha = 1.0

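        # Merge in place: W <- W + multiplier * (alpha / rank) * (up @ down), with rank = weight_up.shape[1];
        # 4-D (1x1 conv) LoRA weights are squeezed to 2-D for the matmul and reshaped back.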
        curr_layer.weight.data = curr_layer.weight.data.to(device)
        if len(weight_up.shape) == 4:
            curr_layer.weight.data += multiplier * alpha * torch.mm(
                weight_up.squeeze(3).squeeze(2),
                weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)

    return pipeline
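

# Minimal usage sketch (assumes a diffusers StableDiffusionPipeline; the model id,
# LoRA directory and prompt below are placeholders, not part of this module):
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')
#   pipe = merge_lora(pipe, './lora_output', multiplier=1.0)
#   image = pipe('a photo of a dog').images[0]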