# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
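
"""Merge LoRA adapters into Qwen2.5-Omni checkpoints.

merge_lora merges LoRA weights into one submodule of the base model (the
"thinker" by default) and saves the result; save_full rebuilds a complete
Qwen2.5-Omni checkpoint from a separately saved thinker. Both commands are
exposed on the command line via fire (see the examples at the end of this file).
"""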

import os
import shutil

import fire
from peft import PeftModel
from transformers import (
    AutoProcessor,
    Qwen2_5OmniForConditionalGeneration,  # type: ignore
    Qwen2_5OmniThinkerForConditionalGeneration,
)


def merge_lora(
    base_model_path: str,
    lora_checkpoint_path: str,
    extra_file: str = "spk_dict.pt",
    submodule_name: str = "thinker",
    save_path: str = "./merged_model_checkpoint",
):
    """Load the original model, merge the LoRA weights.

    The final merged model is then saved along with its configurations.

    Args:
        base_model_path (str): Path to the original model directory.
        lora_checkpoint_path (str): Path to the directory containing LoRA weights.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
        submodule_name (str): Name of the submodule to merge (default: "thinker").
        save_path (str): Directory where the merged model and configurations will be saved.
    """
    # 1. Load the original model
    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
    print("Successfully loaded the original model.")

    # 2. Extract the submodule to be merged (e.g., model.thinker)
    if not hasattr(model, submodule_name):
        raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")

    base_submodule = getattr(model, submodule_name)
    print(f"Successfully extracted submodule: {submodule_name}.")

    # 3. Load the LoRA weights onto the extracted submodule
    lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
    processor = AutoProcessor.from_pretrained(lora_checkpoint_path)
    print("LoRA weights and processor loaded successfully.")

    # 4. Merge the LoRA weights into the submodule and unload the LoRA modules
    merged_submodule = lora_model.merge_and_unload()
    print("LoRA weights merged successfully.")

    # 5. Replace the original submodule with the merged submodule in the model
    setattr(model, submodule_name, merged_submodule)

    # 6. Save the final merged model along with the tokenizer and processor configuration
    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Merged model and tokenizer saved to {save_path}.")

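    # 7. Copy the extra file from the base model directory to the save path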
    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")


def save_full_model(
    saved_thinker_path: str,
    base_model_path: str,
    save_path: str = "./merged_model_checkpoint",
    extra_file: str = "spk_dict.pt",
):
    """Load the saved thinker module and the original model, replace the thinker in the original model.

    Then save the complete model along with its tokenizer and processor configuration.

    Args:
        saved_thinker_path (str): Path to the saved thinker weights.
        base_model_path (str): Directory path of the original model.
        save_path (str): Directory where the merged model and configurations will be saved.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
    """
    # 1. Load the saved thinker module and the original model
    thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
        saved_thinker_path, torch_dtype="auto", device_map="cpu"
    )
    base_model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        base_model_path, torch_dtype="auto", device_map="cpu"
    )
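
    # Replace the original thinker with the fine-tuned one loaded above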
    base_model.thinker = thinker

    # 2. Save the complete model along with its tokenizer and processor configuration
    processor = AutoProcessor.from_pretrained(saved_thinker_path)
    base_model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Merged model and processor saved to {save_path}.")

    # 3. Copy the extra file from the base model directory to the save_path
    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")


if __name__ == "__main__":
    fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})
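
# Example invocations (the model id and paths below are illustrative):
#   python qwen_omni_merge.py merge_lora \
#       --base_model_path=Qwen/Qwen2.5-Omni-7B \
#       --lora_checkpoint_path=./lora_checkpoint \
#       --save_path=./merged_model_checkpoint
#
#   python qwen_omni_merge.py save_full \
#       --saved_thinker_path=./thinker_checkpoint \
#       --base_model_path=Qwen/Qwen2.5-Omni-7B \
#       --save_path=./merged_model_checkpoint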