# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Why we need this script for qwen_omni?

Because the qwen_omni model is constructed by two parts:
1. [Thinker]:[audio_encoder, vision_encoder, LLM backbone], which our repository does support to post-training.
2. [Talker]: [audio_decoder, wave_model], which is not supported to post-training without specific tokenizer.
When we post-training the model, we exactly train the [Thinker] part, and the [Talker] part is dropped.
So, to get the complete model, we need to merge the [Talker] part back to the [Thinker] part.
LoRA mode: [Thinker + LoRA weights] + [Original Talker] -> [Omni model]
Full mode: [Thinker] + [Original Talker] -> [Omni model]
For Processor, we do saved the processor from trained model instead of the original model.
"""

import os
import shutil

import fire
from peft import PeftModel
from transformers import (
    AutoProcessor,
    Qwen2_5OmniForConditionalGeneration,  # type: ignore
    Qwen2_5OmniThinkerForConditionalGeneration,
)


def merge_lora(
    base_model_path: str,
    lora_checkpoint_path: str,
    extra_file: str = "spk_dict.pt",
    submodule_name: str = "thinker",
    save_path: str = "./merged_model_checkpoint",
):
    """Load the original model, merge the LoRA weights.

    The LoRA weights are merged into the specified submodule, and the final merged model is saved along with its configurations.

    Args:
        base_model_path (str): Path to the original model directory.
        lora_checkpoint_path (str): Path to the directory containing LoRA weights.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
        submodule_name (str): Name of the submodule to merge (default: "thinker").
        save_path (str): Directory where the merged model and configurations will be saved.
    """
    # 1. Load the original model
    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
    print("Successfully loaded the original model.")

    # 2. Extract the submodule to be merged (e.g., model.thinker)
    if not hasattr(model, submodule_name):
        raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")

    base_submodule = getattr(model, submodule_name)
    print(f"Successfully extracted submodule: {submodule_name}.")

    # 3. Load the LoRA weights onto the extracted submodule
    lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
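    # Load the processor from the trained checkpoint rather than the base model (see module docstring)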
    processor = AutoProcessor.from_pretrained(lora_checkpoint_path)
    print("LoRA weights and processor loaded successfully.")

    # 4. Merge the LoRA weights into the submodule and unload the LoRA modules
    merged_submodule = lora_model.merge_and_unload()
    print("LoRA weights merged successfully.")

    # 5. Replace the original submodule with the merged submodule in the model
    setattr(model, submodule_name, merged_submodule)

    # 6. Save the final merged model along with the tokenizer and processor configuration
    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Merged model and tokenizer saved to {save_path}.")

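    # 7. Copy the extra file from the base model directory to the save path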
    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")


def save_full_model(
    saved_thinker_path: str,
    base_model_path: str,
    save_path: str = "./merged_model_checkpoint",
    extra_file: str = "spk_dict.pt",
):
    """Load the saved thinker module and the original model, replace the thinker in the original model.

    Then save the complete model along with its tokenizer and processor configuration.

    Args:
        saved_thinker_path (str): Path to the saved thinker weights.
        base_model_path (str): Directory path of the original model.
        save_path (str): Directory where the merged model and configurations will be saved.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
    """
    # 1. Load the saved thinker module and the original model
    thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
        saved_thinker_path, torch_dtype="auto", device_map="cpu"
    )
    base_model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        base_model_path, torch_dtype="auto", device_map="cpu"
    )
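    # Swap the trained thinker into the full model; the talker keeps its original weights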
    base_model.thinker = thinker

    # 2. Save the complete model along with its tokenizer and processor configuration
    processor = AutoProcessor.from_pretrained(saved_thinker_path)
    base_model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Merged model and processor saved to {save_path}.")

    # 3. Copy the extra file from the base model directory to the save_path
    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")


if __name__ == "__main__":
    fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})