# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Why we need this script for qwen_omni?

Because the qwen_omni model is constructed by two parts:
1. [Thinker]:[audio_encoder, vision_encoder, LLM backbone], which our repository does support to post-training.
2. [Talker]: [audio_decoder, wave_model], which is not supported to post-training without specific tokenizer.
When we post-training the model, we exactly train the [Thinker] part, and the [Talker] part is dropped.
So, to get the complete model, we need to merge the [Talker] part back to the [Thinker] part.
LoRA mode: [Thinker + LoRA weights] + [Original Talker] -> [Omni model]
Full mode: [Thinker] + [Original Talker] -> [Omni model]
For Processor, we do saved the processor from trained model instead of the original model.
"""

import os
import shutil

import fire
from peft import PeftModel
from transformers import AutoConfig, AutoModelForTextToWaveform, AutoProcessor
from transformers.utils import cached_file


def merge_lora(
    model_path: str,
    lora_path: str,
    save_path: str = "./merged_model_checkpoint",
    extra_file: str = "spk_dict.pt",
    submodule_name: str = "thinker",
):
    """Load the original model, merge the LoRA weights.

    Then save the final merged model along with its configurations.

    Args:
        model_path (str): Path to the original model directory.
        lora_path (str): Path to the directory containing LoRA weights.
        save_path (str): Directory where the merged model and configurations will be saved.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
        submodule_name (str): Name of the submodule to merge (default: "thinker").
    """
    # 1. Load the original model
    model = AutoModelForTextToWaveform.from_pretrained(model_path, torch_dtype="auto", device_map="cpu")
    print("Successfully loaded the original model.")

    # 2. Extract the submodule to be merged (e.g., model.thinker)
    if not hasattr(model, submodule_name):
        raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")

    base_submodule = getattr(model, submodule_name)
    print(f"Successfully extracted submodule: {submodule_name}.")

    # 3. Load the LoRA weights onto the extracted submodule
    lora_model = PeftModel.from_pretrained(base_submodule, lora_path)
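    # Note: the processor is loaded from the trained checkpoint (lora_path), not from the original model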
    processor = AutoProcessor.from_pretrained(lora_path)
    print("Successfully loaded LoRA weights and processor.")

    # 4. Merge the LoRA weights into the submodule and unload the LoRA modules
    merged_submodule = lora_model.merge_and_unload()
    print("Successfully merged LoRA weights.")

    # 5. Replace the original submodule with the merged submodule in the model
    setattr(model, submodule_name, merged_submodule)

    # 6. Save the final merged model along with the tokenizer and processor configuration
    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Merged model and processor saved to {save_path}.")

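    # 7. Copy the extra file (by default the speaker dictionary "spk_dict.pt") from the original model directory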
    try:
        source_file = cached_file(path_or_repo_id=model_path, filename=extra_file)
        shutil.copy(source_file, os.path.join(save_path, extra_file))
        print(f"File '{extra_file}' copied from {model_path} to {save_path}.")
    except Exception:
        print(f"File '{extra_file}' not found in {model_path}, skipping copy.")


def save_full_model(
    model_path: str,
    thinker_path: str,
    save_path: str = "./merged_model_checkpoint",
    extra_file: str = "spk_dict.pt",
):
    """Load the saved thinker module and the original model, replace the thinker in the original model.

    Then save the complete model along with its tokenizer and processor configuration.

    Args:
        model_path (str): Directory path of the original model.
        thinker_path (str): Path to the saved thinker weights.
        save_path (str): Directory where the merged model and configurations will be saved.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
    """
    # 1. Load the saved thinker module and the original model
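    # Pick the thinker class that matches the base model's architecture (Qwen2.5-Omni or Qwen3-Omni-MoE)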
    config = AutoConfig.from_pretrained(model_path)
    if getattr(config, "model_type") == "qwen2_5_omni":
        from transformers.models.qwen2_5_omni import Qwen2_5OmniThinkerForConditionalGeneration  # type: ignore

        ThinkerClass = Qwen2_5OmniThinkerForConditionalGeneration
    elif getattr(config, "model_type") == "qwen3_omni_moe":
        from transformers.models.qwen3_omni_moe import Qwen3OmniMoeThinkerForConditionalGeneration  # type: ignore

        ThinkerClass = Qwen3OmniMoeThinkerForConditionalGeneration
    else:
        raise ValueError(f"Unsupported model type: {getattr(config, 'model_type')}.")

    thinker = ThinkerClass.from_pretrained(thinker_path, torch_dtype="auto", device_map="cpu")
    base_model = AutoModelForTextToWaveform.from_pretrained(model_path, torch_dtype="auto", device_map="cpu")
    base_model.thinker = thinker
    processor = AutoProcessor.from_pretrained(thinker_path)
    print("Successfully loaded model weights and processor.")

    # 2. Save the complete model along with its tokenizer and processor configuration
    base_model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Merged model and processor saved to {save_path}.")

    # 3. Copy the extra file from the base model directory to the save_path
    try:
        source_file = cached_file(path_or_repo_id=model_path, filename=extra_file)
        shutil.copy(source_file, os.path.join(save_path, extra_file))
        print(f"File '{extra_file}' copied from {model_path} to {save_path}.")
    except Exception:
        print(f"File '{extra_file}' not found in {model_path}, skipping copy.")


if __name__ == "__main__":
    fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})