# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from trl import GRPOConfig, GRPOTrainer

from ...data import get_template_and_fix_tokenizer
from ...extras.logging import get_logger
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push

# The `get_*` dataset loaders and `*_reward_func` reward functions used below
# are provided by these two local modules.
from .data import *
from .func import *

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)


def run_grpo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
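    r"""Run GRPO training.

    Loads the tokenizer and model, builds train/eval prompt datasets from
    `data_args.dataset`, and trains with TRL's `GRPOTrainer` using the
    rule-based reward functions imported above.
    """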
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
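    # Called for its side effect: it patches the tokenizer (special tokens /
    # chat template) to match the template; the returned template is unused here.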
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    # Build the train/eval sets from the configured dataset names. Each loader
    # returns a list of GRPO-ready examples; loaders with a "test" split also
    # contribute to the eval set.
    train_dataset = []
    eval_dataset = []
    datasets_list = data_args.dataset
    logger.info("Training datasets: {}".format(datasets_list))
    for index, dataset_name in enumerate(datasets_list):
        dataset_name = dataset_name.strip()
        logger.info("[{}/{}] loading {}".format(index + 1, len(datasets_list), dataset_name))
        if "hiyouga-math12k" in dataset_name:
            func = get_hiyoga
            eval_dataset.extend(get_hiyoga(split="test"))
        elif "openai/gsm8k" in dataset_name:
            func = get_gsm8k_questions
            eval_dataset.extend(get_gsm8k_questions(split="test"))
        elif "Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT" in dataset_name:
            func = get_deepseek_r1_questions
        elif "OpenMathReasoning-mini" in dataset_name:
            func = get_unsloth_openmath
        elif "dapo_math" in dataset_name:
            func = get_openr1_dapo_math
        else:
            raise ValueError("Unknown dataset: {}".format(dataset_name))

        train_dataset.extend(func())

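    # `training_args.do_train` toggles whether the model is loaded as trainable.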
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

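    # Build the TRL GRPO config. Optimization settings come from `training_args`
    # and sampling settings from `generating_args`; `num_generations`, the Adam
    # betas, and the prompt/completion length caps are hardcoded below.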
    grpo_training_args = GRPOConfig(
        do_train=True,
        learning_rate=training_args.learning_rate,
        per_device_train_batch_size=training_args.per_device_train_batch_size,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        num_train_epochs=training_args.num_train_epochs,
        seed=training_args.seed,
        num_generations=8,
        lr_scheduler_type=training_args.lr_scheduler_type,
        adam_beta1=0.9,
        adam_beta2=0.99,
        adam_epsilon=1e-8,
        weight_decay=training_args.weight_decay,
        warmup_ratio=training_args.warmup_ratio,
        logging_steps=training_args.logging_steps,
        bf16=True,
        save_strategy="steps",
        save_steps=training_args.save_steps,
        output_dir=training_args.output_dir,
        max_prompt_length=1024,
        max_completion_length=2048,
        max_grad_norm=0.1,
        ddp_timeout=1800000,
        temperature=generating_args.temperature,
        top_p=generating_args.top_p,
        top_k=generating_args.top_k,
        repetition_penalty=generating_args.repetition_penalty,
        loss_type="grpo",
        use_vllm=True,
        vllm_mode="server",  # default value, can be omitted
        vllm_server_base_url=finetuning_args.vllm_server_base_url,
        report_to="none",
        deepspeed=training_args.deepspeed,
    )

    # GRPOTrainer takes its sampling settings (temperature, top_p, etc.) from
    # the GRPOConfig above and logs reward statistics itself, so the separate
    # `compute_metrics` / `gen_kwargs` setup used by generation-based eval
    # workflows is not needed here.

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[  # applied in order; per-completion rewards are summed
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=grpo_training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=callbacks,
        peft_config=None,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss"])

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)