"vscode:/vscode.git/clone" did not exist on "2e54d5f518a4c7db721483347bec6676a8d20c89"
workflow.py 7.44 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)


def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
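    """Run the supervised fine-tuning (SFT) workflow.

    Loads the tokenizer, chat template, dataset, and model, builds the SFT data
    collator, then runs training, evaluation, and prediction as requested by
    `training_args`.
    """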
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    if getattr(model, "is_quantized", False) and not training_args.do_train:
        setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction

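    # The SFT collator pads labels with `IGNORE_INDEX` (when `ignore_pad_token_for_loss` is set) so
    # padding tokens are excluded from the loss, and builds block-diagonal 4D attention masks for
    # packed sequences when `block_diag_attn` is enabled.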
    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        model=model if not training_args.predict_with_generate else None,
        pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        block_diag_attn=model_args.block_diag_attn,
        attn_implementation=getattr(model.config, "_attn_implementation", None),
        compute_dtype=model_args.compute_dtype,
        **tokenizer_module,
    )

    # Metric utils
    metric_module = {}
    if model_args.use_kt:
        if training_args.predict_with_generate:
            raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
        elif finetuning_args.compute_accuracy:
            raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")

    if training_args.predict_with_generate:
        metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
    elif finetuning_args.compute_accuracy:
        metric_module["compute_metrics"] = ComputeAccuracy()
        metric_module["preprocess_logits_for_metrics"] = eval_logit_processor

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict(obey_generation_config=True)

    # Compatible with Transformers v4 and Transformers v5
    if is_transformers_version_greater_than("4.58.0"):
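        # Newer tokenizers may not expose `additional_special_tokens_ids` as a list, so fall back to
        # converting the extra special tokens to ids, then merge them with the EOS id and
        # de-duplicate while preserving order.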
        extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
        if not isinstance(extra_ids, list):
            extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
            string_tokens = [str(t) for t in extra_special_tokens]
            extra_ids = tokenizer.convert_tokens_to_ids(string_tokens)
        all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
        unique_eos_ids = list(dict.fromkeys(all_eos_ids))
        gen_kwargs["eos_token_id"] = unique_eos_ids
    else:
        gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id

    # Initialize our Trainer
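    # When `use_kt` is enabled, training is delegated to KTransformers' KTrainer with its global
    # config switched to SFT mode; otherwise the project's CustomSeq2SeqTrainer is used.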
    if model_args.use_kt:
        from ktransformers.sft.lora import KTrainer  # type: ignore
        from ktransformers.util.globals import GLOBAL_CONFIG  # type: ignore

        GLOBAL_CONFIG._config["mod"] = "sft"

        trainer = KTrainer(
            model=model,
            args=training_args,
            tokenizer=tokenizer_module,
            data_collator=data_collator,
            callbacks=callbacks,
            **dataset_module,
            **metric_module,
        )
        trainer.model_accepts_loss_kwargs = False
        model.config.use_cache = False

    else:
        trainer = CustomSeq2SeqTrainer(
            model=model,
            args=training_args,
            finetuning_args=finetuning_args,
            data_collator=data_collator,
            callbacks=callbacks,
            gen_kwargs=gen_kwargs,
            **dataset_module,
            **tokenizer_module,
            **metric_module,
        )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        if finetuning_args.include_effective_tokens_per_second:
            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
                dataset_module["train_dataset"], train_result.metrics, stage="sft"
            )

        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
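            # Plot the training loss, plus per-dataset eval loss/accuracy curves when multiple
            # eval datasets are provided.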
            keys = ["loss"]
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += sum(
                    [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
                )
            else:
                keys += ["eval_loss", "eval_accuracy"]

            plot_loss(training_args.output_dir, keys=keys)

    if training_args.predict_with_generate:
        tokenizer.padding_side = "left"  # use left-padding in generation

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)