# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from transformers import PreTrainedModel

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft


if TYPE_CHECKING:
    from transformers import TrainerCallback


logger = logging.get_logger(__name__)


def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
    """Dispatch training to the runner that matches `finetuning_args.stage`."""
    callbacks = callbacks if callbacks is not None else []  # avoid a shared mutable default argument
    callbacks.append(LogCallback())
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

    if finetuning_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "sft":
        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "rm":
        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "ppo":
        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "dpo":
        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "kto":
        run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
    else:
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")
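
# A minimal usage sketch (the config keys and values below are illustrative; the full set of
# accepted keys is defined by the hparams dataclasses consumed by get_train_args):
#
#     run_exp({
#         "stage": "sft",
#         "do_train": True,
#         "model_name_or_path": "meta-llama/Llama-2-7b-hf",
#         "finetuning_type": "lora",
#         "dataset": "alpaca_en_demo",
#         "template": "llama2",
#         "output_dir": "saves/llama2-7b/lora/sft",
#     })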


def export_model(args: Optional[Dict[str, Any]] = None) -> None:
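    """Export the (optionally adapter-merged) model, tokenizer, and processor to `export_dir`.

    The model is loaded with `load_model`, cast to the export dtype, and saved with
    `save_pretrained`; for the "rm" stage the value head weights are copied alongside it.
    If `export_hub_model_id` is set, the artifacts are also pushed to the Hugging Face Hub.
    """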
    model_args, data_args, finetuning_args, _ = get_infer_args(args)

    if model_args.export_dir is None:
        raise ValueError("Please specify `export_dir` to save the model.")

    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
        raise ValueError("Please merge adapters before quantizing the model.")

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]
    get_template_and_fix_tokenizer(tokenizer, data_args)
    model = load_model(tokenizer, model_args, finetuning_args)  # must be loaded after fixing the tokenizer so the resized vocab is applied

    if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
        raise ValueError("Cannot merge adapters into a quantized model.")

    if not isinstance(model, PreTrainedModel):
        raise ValueError("The model is not a `PreTrainedModel`; export aborted.")

    if getattr(model, "quantization_method", None) is not None:  # quantized models are exported in float16
        setattr(model.config, "torch_dtype", torch.float16)
    else:
        if model_args.infer_dtype == "auto":
            output_dtype = getattr(model.config, "torch_dtype", torch.float16)
        else:
            output_dtype = getattr(torch, model_args.infer_dtype)

        setattr(model.config, "torch_dtype", output_dtype)
        model = model.to(output_dtype)
        logger.info_rank0(f"Converting model dtype to: {output_dtype}.")

    model.save_pretrained(
        save_directory=model_args.export_dir,
        max_shard_size=f"{model_args.export_size}GB",
        safe_serialization=(not model_args.export_legacy_format),
    )
    if model_args.export_hub_model_id is not None:
        model.push_to_hub(
            model_args.export_hub_model_id,
            token=model_args.hf_hub_token,
            max_shard_size=f"{model_args.export_size}GB",
            safe_serialization=(not model_args.export_legacy_format),
        )

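    # The reward model stores its value head in a separate weights file; copy it next to the exported model.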
    if finetuning_args.stage == "rm":
        if model_args.adapter_name_or_path is not None:
            vhead_path = model_args.adapter_name_or_path[-1]
        else:
            vhead_path = model_args.model_name_or_path

        if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
            shutil.copy(
                os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
                os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
            )
            logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
        elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
            shutil.copy(
                os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
                os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
            )
            logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")

    try:
        tokenizer.padding_side = "left"  # restore left padding for generation
        tokenizer.init_kwargs["padding_side"] = "left"
        tokenizer.save_pretrained(model_args.export_dir)
        if model_args.export_hub_model_id is not None:
            tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

        if processor is not None:
            processor.save_pretrained(model_args.export_dir)
            if model_args.export_hub_model_id is not None:
                processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

    except Exception as e:
        logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
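

# A minimal export sketch (illustrative paths and values; the accepted keys follow the same
# hparams definitions consumed by get_infer_args):
#
#     export_model({
#         "model_name_or_path": "meta-llama/Llama-2-7b-hf",
#         "adapter_name_or_path": "saves/llama2-7b/lora/sft",
#         "template": "llama2",
#         "finetuning_type": "lora",
#         "export_dir": "models/llama2-7b-sft",
#         "export_size": 5,
#     })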