# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Union

from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
from ..common import get_save_dir, load_config
from ..locales import ALERTS


if is_gradio_available():
    import gradio as gr


if TYPE_CHECKING:
    from gradio.components import Component

    from ..engine import Engine


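# bit-widths offered for GPTQ post-training quantization during export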
GPTQ_BITS = ["8", "4", "3", "2"]


def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
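    r"""Disable the quantization dropdown when adapter checkpoints (a list) are selected; keep it editable otherwise."""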
    if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
        return gr.Dropdown(value="none", interactive=False)
    else:
        return gr.Dropdown(interactive=True)


def save_model(
    lang: str,
    model_name: str,
    model_path: str,
    finetuning_type: str,
    checkpoint_path: Union[str, list[str]],
    template: str,
    export_size: int,
    export_quantization_bit: str,
    export_quantization_dataset: str,
    export_device: str,
    export_legacy_format: bool,
    export_dir: str,
    export_hub_model_id: str,
    extra_args: str,
) -> Generator[str, None, None]:
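    r"""Validate the export settings, assemble the arguments, and run ``export_model``, yielding status messages for the UI."""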
    user_config = load_config()
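    # validate the export configuration before launching the export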
    error = ""
    if not model_name:
        error = ALERTS["err_no_model"][lang]
    elif not model_path:
        error = ALERTS["err_no_path"][lang]
    elif not export_dir:
        error = ALERTS["err_no_export_dir"][lang]
    elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
        error = ALERTS["err_no_dataset"][lang]
    elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path:
        error = ALERTS["err_no_adapter"][lang]
    elif export_quantization_bit in GPTQ_BITS and checkpoint_path and isinstance(checkpoint_path, list):
        error = ALERTS["err_gptq_lora"][lang]

    try:
        json.loads(extra_args)
    except json.JSONDecodeError:
        error = ALERTS["err_json_schema"][lang]

    if error:
        gr.Warning(error)
        yield error
        return

    args = dict(
        model_name_or_path=model_path,
        cache_dir=user_config.get("cache_dir", None),
        finetuning_type=finetuning_type,
        template=template,
        export_dir=export_dir,
        export_hub_model_id=export_hub_model_id or None,
        export_size=export_size,
        export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
        export_quantization_dataset=export_quantization_dataset,
        export_device=export_device,
        export_legacy_format=export_legacy_format,
        trust_remote_code=True,
    )
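    # merge the user-provided JSON overrides into the export arguments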
    args.update(json.loads(extra_args))

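    # PEFT adapters are passed via adapter_name_or_path; full checkpoints replace the base model path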
    if checkpoint_path:
        if finetuning_type in PEFT_METHODS:  # list
            args["adapter_name_or_path"] = ",".join(
                [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
            )
        else:  # str
            args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)

    yield ALERTS["info_exporting"][lang]
    export_model(args)
    torch_gc()
    yield ALERTS["info_exported"][lang]


def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
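    r"""Build the components of the export tab and bind the export button to ``save_model``."""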
    with gr.Row():
        export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1)
        export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
        export_quantization_dataset = gr.Textbox(value="data/c4_demo.jsonl")
        export_device = gr.Radio(choices=["cpu", "auto"], value="cpu")
        export_legacy_format = gr.Checkbox()

    with gr.Row():
        export_dir = gr.Textbox()
        export_hub_model_id = gr.Textbox()
        extra_args = gr.Textbox(value="{}")

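    # selecting adapter checkpoints disables GPTQ quantization (see can_quantize)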
    checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
    checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)

    export_btn = gr.Button()
    info_box = gr.Textbox(show_label=False, interactive=False)

    export_btn.click(
        save_model,
        [
            engine.manager.get_elem_by_id("top.lang"),
            engine.manager.get_elem_by_id("top.model_name"),
            engine.manager.get_elem_by_id("top.model_path"),
            engine.manager.get_elem_by_id("top.finetuning_type"),
            engine.manager.get_elem_by_id("top.checkpoint_path"),
            engine.manager.get_elem_by_id("top.template"),
            export_size,
            export_quantization_bit,
            export_quantization_dataset,
            export_device,
            export_legacy_format,
            export_dir,
            export_hub_model_id,
            extra_args,
        ],
        [info_box],
    )

    return dict(
        export_size=export_size,
        export_quantization_bit=export_quantization_bit,
        export_quantization_dataset=export_quantization_dataset,
        export_device=export_device,
        export_legacy_format=export_legacy_format,
        export_dir=export_dir,
        export_hub_model_id=export_hub_model_id,
        extra_args=extra_args,
        export_btn=export_btn,
        info_box=info_box,
    )