# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from ...data import get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger

# from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor

from trl import GRPOConfig, GRPOTrainer

from .data import *
from .func import *


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)


def run_grpo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    r"""Run GRPO training on the configured datasets using TRL's `GRPOTrainer`."""
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    # dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module)

    # Load datasets: map each configured dataset name to its loader (provided by the wildcard imports above).
    train_dataset = []
    eval_dataset = []
    datasets_list = data_args.dataset
    logger.info("Training datasets: {}".format(datasets_list))
    for idx, dataset in enumerate(datasets_list):
        dataset = dataset.strip()
        logger.info("[{}/{}] processing {}".format(idx + 1, len(datasets_list), dataset))
        if "hiyouga-math12k" in dataset:
            func = get_hiyoga
            eval_dataset.extend(get_hiyoga(split="test"))
        elif "openai/gsm8k" in dataset:
            func = get_gsm8k_questions
            eval_dataset.extend(get_gsm8k_questions(split="test"))
        elif "Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT" in dataset:
            func = get_deepseek_r1_questions
        elif "OpenMathReasoning-mini" in dataset:
            func = get_unsloth_openmath
        elif "dapo_math" in dataset:
            func = get_openr1_dapo_math
        else:
            logger.warning("Unknown dataset {}, skipping it.".format(dataset))
            continue

        train_dataset.extend(func())

    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    # Build the GRPO config, forwarding most options from the training and generation arguments.
    grpo_training_args = GRPOConfig(
        do_train=True,
        learning_rate=training_args.learning_rate,
        per_device_train_batch_size=training_args.per_device_train_batch_size,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        num_train_epochs=training_args.num_train_epochs,
        seed=training_args.seed,
        num_generations=8,
        lr_scheduler_type=training_args.lr_scheduler_type,
        adam_beta1=0.9,
        adam_beta2=0.99,
        adam_epsilon=1e-08,
        weight_decay=training_args.weight_decay,
        warmup_ratio=training_args.warmup_ratio,
        logging_steps=training_args.logging_steps,
        bf16=True,
        save_strategy="steps",
        save_steps=training_args.save_steps,
        output_dir=training_args.output_dir,
        max_prompt_length=1024,
        max_completion_length=2048,
        max_grad_norm=0.1,
        ddp_timeout=1800000,
        temperature=generating_args.temperature,
        top_p=generating_args.top_p,
        top_k=generating_args.top_k,
        repetition_penalty=generating_args.repetition_penalty,
        loss_type="grpo",
        use_vllm=True,
        vllm_mode="server",  # default value, can be omitted
        vllm_server_base_url=finetuning_args.vllm_server_base_url,
        report_to="none",
        deepspeed=training_args.deepspeed,
    )

    # Metric utils (currently not passed to the trainer)
    metric_module = {}
    if training_args.predict_with_generate:
        metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
    elif finetuning_args.compute_accuracy:
        metric_module["compute_metrics"] = ComputeAccuracy()
        metric_module["preprocess_logits_for_metrics"] = eval_logit_processor

    # Keyword arguments for `model.generate` (currently not passed to the trainer)
    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id

    # The reward functions below come from the wildcard imports above.
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=grpo_training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=None,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        # trainer.save_model()
        trainer.save_state()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)

        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss"])

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
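

# Usage sketch (assumption): in LlamaFactory, training stages are dispatched from the tuner
# entrypoint after the CLI/YAML arguments are parsed, and a GRPO stage would be wired up the
# same way. The argument-parsing helper below is assumed from the surrounding codebase and
# may differ in name or signature; this is not the confirmed call site.
#
#     from ...hparams import get_train_args
#
#     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
#     run_grpo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=[])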