# Copyright 2024 Bytedance Ltd. and/or its affiliates # Copyright 2023-2024 SGLang Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ FSDP PPO Trainer with Ray-based single controller. This trainer supports model-agonistic model initialization with huggingface """ import uuid from copy import deepcopy from pprint import pprint from typing import Optional import numpy as np import ray import torch from torch.utils.data import Dataset, Sampler from tqdm import tqdm from verl import DataProto from verl.single_controller.ray import RayWorkerGroup from verl.trainer.ppo import core_algos from verl.trainer.ppo.core_algos import agg_loss from verl.trainer.ppo.metric_utils import reduce_metrics from verl.trainer.ppo.ray_trainer import ( AdvantageEstimator, RayPPOTrainer, ResourcePoolManager, Role, WorkerType, apply_kl_penalty, compute_response_mask, ) from verl.trainer.ppo.reward import compute_reward, compute_reward_async from verl.utils.profiler.performance import simple_timer from verl.utils.tracking import ValidationGenerationsLogger def softmean(x: torch.Tensor, beta: float, dim: int = -1, keepdim: bool = False) -> torch.Tensor: """ Compute SoftMean_β(x) = (1/β) * log( (1/n) * Σ exp(β * x_i) ) Falls back to arithmetic mean when β=0. """ if beta == 0.0: return x.mean(dim=dim, keepdim=keepdim) # cast beta to tensor on same device/dtype beta_t = x.new_tensor(beta) # numerically-stable logsumexp(β x) lse = torch.logsumexp(x * beta_t, dim=dim, keepdim=keepdim) n = x.size(dim) log_n = x.new_tensor(n).log() return (lse - log_n) / beta_t def compute_advantage(data: DataProto, beta=1.0): rewards = data.batch["token_level_rewards"].sum(axis=-1) # (bs, ) s_mean = softmean(rewards, beta, keepdim=True) # (bs, ) rewards = rewards - s_mean # (bs, ) data.batch["seq_level_rewards"] = rewards # (bs, ) return data class RaySPPOTrainer(RayPPOTrainer): """ Note that this trainer runs on the driver process on a single CPU/GPU node. """ # TODO: support each role have individual ray_worker_group_cls, # i.e., support different backend of different role def __init__( self, config, tokenizer, role_worker_mapping: dict[Role, WorkerType], resource_pool_manager: ResourcePoolManager, ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, processor=None, reward_fn=None, val_reward_fn=None, train_dataset: Optional[Dataset] = None, val_dataset: Optional[Dataset] = None, collate_fn=None, train_sampler: Optional[Sampler] = None, device_name=None, ): self.tokenizer = tokenizer self.processor = processor self.config = config self.reward_fn = reward_fn self.val_reward_fn = val_reward_fn self.hybrid_engine = config.actor_rollout_ref.hybrid_engine assert self.hybrid_engine, "Currently, only support hybrid engine" if self.hybrid_engine: assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" self.role_worker_mapping = role_worker_mapping self.resource_pool_manager = resource_pool_manager self.use_reference_policy = Role.RefPolicy in role_worker_mapping self.use_rm = Role.RewardModel in role_worker_mapping self.ray_worker_group_cls = ray_worker_group_cls self.validation_generations_logger = ValidationGenerationsLogger() self.device_name = device_name if device_name else self.config.trainer.device # define in-reward KL control # kl loss control currently not supported if config.algorithm.use_kl_in_reward: self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) self.use_critic = False self._validate_config() self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) def fit(self): """ The training loop of PPO. The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. The light-weight advantage computation is done on the driver process. """ from omegaconf import OmegaConf from verl.utils.tracking import Tracking logger = Tracking( project_name=self.config.trainer.project_name, experiment_name=self.config.trainer.experiment_name, default_backend=self.config.trainer.logger, config=OmegaConf.to_container(self.config, resolve=True), ) self.global_steps = 0 # load checkpoint before doing anything self._load_checkpoint() # perform validation before training # currently, we only support validation using the reward_function. if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): val_metrics = self._validate() pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) if self.config.trainer.get("val_only", False): return # add tqdm progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") # we start from step 1 self.global_steps += 1 last_val_metrics = None for epoch in range(self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: metrics = {} timing_raw = {} batch: DataProto = DataProto.from_single_dict(batch_dict) # pop those keys for generation batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] if "multi_modal_data" in batch.non_tensor_batch: non_tensor_batch_keys_to_pop.append("multi_modal_data") if "raw_prompt" in batch.non_tensor_batch: non_tensor_batch_keys_to_pop.append("raw_prompt") if "tools_kwargs" in batch.non_tensor_batch: non_tensor_batch_keys_to_pop.append("tools_kwargs") gen_batch = batch.pop( batch_keys=batch_keys_to_pop, non_tensor_batch_keys=non_tensor_batch_keys_to_pop, ) gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) is_last_step = self.global_steps >= self.total_training_steps with simple_timer("step", timing_raw): # generate a batch with simple_timer("gen", timing_raw): if not self.async_rollout_mode: gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) else: gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: with simple_timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) batch = batch.union(gen_baseline_output) reward_baseline_tensor = self.reward_fn(batch) reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) batch.batch["reward_baselines"] = reward_baseline_tensor del gen_baseline_batch, gen_baseline_output batch.non_tensor_batch["uid"] = np.array( [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object ) # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) batch.batch["response_mask"] = compute_response_mask(batch) # Balance the number of valid tokens across DP ranks. # NOTE: This usually changes the order of data in the `batch`, # which won't affect the advantage calculation (since it's based on uid), # but might affect the loss calculation (due to the change of mini-batching). # TODO: Decouple the DP balancing and mini-batching. if self.config.trainer.balance_batch: self._balance_batch(batch, metrics=metrics) # compute global_valid tokens batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() with simple_timer("reward", timing_raw): # compute reward model score if self.use_rm: reward_tensor = self.rm_wg.compute_rm_score(batch) batch = batch.union(reward_tensor) if self.config.reward_model.launch_reward_fn_async: future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer) else: reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) # recompute old_log_probs with simple_timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) if self.use_reference_policy: # compute reference log_prob with simple_timer("ref", timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) # compute values if self.use_critic: with simple_timer("values", timing_raw): values = self.critic_wg.compute_values(batch) batch = batch.union(values) with simple_timer("adv", timing_raw): # we combine with rule-based rm reward_extra_infos_dict: dict[str, list] if self.config.reward_model.launch_reward_fn_async: reward_tensor, reward_extra_infos_dict = ray.get(future_reward) batch.batch["token_level_scores"] = reward_tensor if reward_extra_infos_dict: batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: batch, kl_metrics = apply_kl_penalty( batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty ) metrics.update(kl_metrics) else: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] batch.batch["seq_level_rewards"] = batch.batch["token_level_scores"] beta = self.config.algorithm.sppo_eta batch = compute_advantage(batch, beta=beta) # update critic if self.use_critic: with simple_timer("update_critic", timing_raw): critic_output = self.critic_wg.update_critic(batch) critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) metrics.update(critic_output_metrics) # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor with simple_timer("update_actor", timing_raw): batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable actor_output = self.actor_rollout_wg.update_actor(batch) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) # Log rollout generations if enabled rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) if rollout_data_dir: with simple_timer("dump_rollout_generations", timing_raw): print(batch.batch.keys()) inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() self._dump_generations( inputs=inputs, outputs=outputs, scores=scores, reward_extra_infos_dict=reward_extra_infos_dict, dump_path=rollout_data_dir, ) # validate if ( self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) ): with simple_timer("testing", timing_raw): val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) if self.config.trainer.save_freq > 0 and ( is_last_step or self.global_steps % self.config.trainer.save_freq == 0 ): with simple_timer("save_checkpoint", timing_raw): self._save_checkpoint() # training metrics metrics.update( { "training/global_step": self.global_steps, "training/epoch": epoch, } ) # TODO: make a canonical logger that supports various backend logger.log(data=metrics, step=self.global_steps) if is_last_step: pprint(f"Final validation metrics: {last_val_metrics}") progress_bar.close() return progress_bar.update(1) self.global_steps += 1