# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

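"""Merge fine-tuned Qwen2.5-Omni "thinker" weights back into the full model.

The script exposes two commands via fire: `merge_lora` merges a LoRA checkpoint
into a submodule of the base model (the "thinker" by default), and `save_full`
swaps a fully fine-tuned thinker into the base model. Both save the merged
checkpoint with its processor configuration and copy the extra file
(`spk_dict.pt` by default) if present.
"""
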
import os
import shutil

import fire
from peft import PeftModel
from transformers import AutoModel, AutoProcessor, Qwen2_5OmniThinkerForConditionalGeneration  # type: ignore


def merge_lora(
    base_model_path: str,
    lora_checkpoint_path: str,
    extra_file: str = "spk_dict.pt",
    submodule_name: str = "thinker",
    save_path: str = "./merged_model_checkpoint",
):
    """Load the original model, tokenizer, and processor configuration, merge the LoRA weights.

    For a specified submodule, and save the final merged model along with its configurations.

    Args:
        base_model_path (str): Path to the original model directory.
        lora_checkpoint_path (str): Path to the directory containing LoRA weights.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
        submodule_name (str): Name of the submodule to merge (default: "thinker").
        save_path (str): Directory where the merged model and configurations will be saved.
    """
    # 1. Load the original model, tokenizer, and processor
    model = AutoModel.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
    processor = AutoProcessor.from_pretrained(base_model_path)
    print("Successfully loaded the original model and tokenizer.")

    # 2. Extract the submodule to be merged (e.g., model.thinker)
    if not hasattr(model, submodule_name):
        raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")

    base_submodule = getattr(model, submodule_name)
    print(f"Successfully extracted submodule: {submodule_name}.")

    # 3. Load the LoRA weights onto the extracted submodule
    lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
    print("LoRA weights loaded successfully.")

    # 4. Merge the LoRA weights into the submodule and unload the LoRA modules
    merged_submodule = lora_model.merge_and_unload()
    print("LoRA weights merged successfully.")

    # 5. Replace the original submodule with the merged submodule in the model
    setattr(model, submodule_name, merged_submodule)

    # 6. Save the final merged model along with the tokenizer and processor configuration
    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Merged model and tokenizer saved to {save_path}.")

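    # 7. Copy the extra file from the base model directory to the save_path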
    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")


def save_full_model(
    saved_thinker_path: str,
    base_model_path: str,
    save_path: str = "./merged_model_checkpoint",
    extra_file: str = "spk_dict.pt",
):
    """Load the saved thinker module and the original model, replace the thinker in the original model.

    Then save the complete model along with its tokenizer and processor configuration.

    Args:
        saved_thinker_path (str): Path to the saved thinker weights.
        base_model_path (str): Directory path of the original model.
        save_path (str): Directory where the merged model and configurations will be saved.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
    """
    # 1. Load the saved thinker module and the original model
    thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
        saved_thinker_path, torch_dtype="auto", device_map="cpu"
    )
    base_model = AutoModel.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
    base_model.thinker = thinker

    # 2. Save the complete model along with its tokenizer and processor configuration
    processor = AutoProcessor.from_pretrained(base_model_path)
    base_model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Merged model and tokenizer saved to {save_path}.")

    # 3. Copy the extra file from the base model directory to the save_path
    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")


if __name__ == "__main__":
    fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})