# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with huggingface
"""

import os
import statistics
import uuid
from copy import deepcopy
from pprint import pprint

import numpy as np
import torch
from omegaconf import OmegaConf, open_dict

from verl import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.metric import reduce_metrics
from verl.utils.profiler.performance import simple_timer

from . import prime_core_algos


def compute_advantage(data: DataProto, adv_estimator, config):
    """Attach ``advantages`` and ``returns`` tensors to ``data.batch``.

    Only the ``"rloo"`` estimator is supported; the actual computation is
    delegated to ``prime_core_algos.compute_rloo_advantage_return``.

    Args:
        data: Batch holding at least ``responses`` and ``attention_mask``.
        adv_estimator: Name of the advantage estimator.
        config: Full trainer config; ``actor_rollout_ref.rollout.n`` is read here.

    Returns:
        The same ``data`` object, mutated in place.

    Raises:
        NotImplementedError: If ``adv_estimator`` is anything other than ``"rloo"``.
    """
    if adv_estimator != "rloo":
        raise NotImplementedError(f"Unsupported advantage estimator: {adv_estimator!r}")

    # The response occupies the trailing columns of the attention mask.
    response_length = data.batch["responses"].size(-1)
    response_mask = data.batch["attention_mask"][:, -response_length:]

    advantages, returns = prime_core_algos.compute_rloo_advantage_return(
        data, response_mask, config.actor_rollout_ref.rollout.n, config
    )
    data.batch["advantages"] = advantages
    data.batch["returns"] = returns
    return data


def compute_data_metrics(batch, use_critic=True):
    """Summarise advantages, returns, (optionally) values and sequence lengths.

    Args:
        batch: Batch holding ``advantages``, ``returns``, ``responses`` and
            ``attention_mask`` (plus ``values`` when ``use_critic`` is true).
        use_critic: When true, also report value statistics and the explained
            variance of the value function.

    Returns:
        A flat dict mapping metric names to Python scalars.
    """
    attention_mask = batch.batch["attention_mask"]
    max_response_length = batch.batch["responses"].shape[-1]

    # Split the attention mask into its prompt (leading) and response (trailing) parts.
    prompt_mask = attention_mask[:, :-max_response_length].bool()
    response_mask = attention_mask[:, -max_response_length:].bool()
    max_prompt_length = prompt_mask.size(-1)

    response_info = _compute_response_info(batch)
    prompt_length = response_info["prompt_length"]
    response_length = response_info["response_length"]

    # Statistics are taken over non-padded response positions only.
    valid_adv = torch.masked_select(batch.batch["advantages"], response_mask)
    valid_returns = torch.masked_select(batch.batch["returns"], response_mask)

    critic_metrics = {}
    if use_critic:
        valid_values = torch.masked_select(batch.batch["values"], response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)
        critic_metrics = {
            # values
            "critic/values/mean": torch.mean(valid_values).detach().item(),
            "critic/values/max": torch.max(valid_values).detach().item(),
            "critic/values/min": torch.min(valid_values).detach().item(),
            # vf explained var (1e-5 guards against a zero return variance)
            "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
        }

    metrics = {
        # adv
        "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
        "critic/advantages/max": torch.max(valid_adv).detach().item(),
        "critic/advantages/min": torch.min(valid_adv).detach().item(),
        # returns
        "critic/returns/mean": torch.mean(valid_returns).detach().item(),
        "critic/returns/max": torch.max(valid_returns).detach().item(),
        "critic/returns/min": torch.min(valid_returns).detach().item(),
        **critic_metrics,
        # response length
        "response_length/mean": torch.mean(response_length).detach().item(),
        "response_length/max": torch.max(response_length).detach().item(),
        "response_length/min": torch.min(response_length).detach().item(),
        # fraction of responses that fill the whole response window
        "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
        .detach()
        .item(),
        # prompt length
        "prompt_length/mean": torch.mean(prompt_length).detach().item(),
        "prompt_length/max": torch.max(prompt_length).detach().item(),
        "prompt_length/min": torch.min(prompt_length).detach().item(),
        # fraction of prompts that fill the whole prompt window
        "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
    }
    return metrics


def compute_response_mask(data: DataProto):
    """Return the trailing slice of ``attention_mask`` that covers the responses."""
    responses = data.batch["responses"]
    response_length = responses.size(1)
    attention_mask = data.batch["attention_mask"]
    return attention_mask[:, -response_length:]


def compute_timing_metrics(batch, timing_raw):
    """Convert raw section timings into logged timing metrics.

    Args:
        batch: Batch used to count prompt and response tokens.
        timing_raw: Mapping of section name to elapsed seconds.

    Returns:
        A dict with ``timing_s/<name>`` for every timed section and
        ``timing_per_token_ms/<name>`` for the sections that have a known
        token count (``gen`` uses response tokens, the rest use all tokens).
    """
    response_info = _compute_response_info(batch)
    num_prompt_tokens = torch.sum(response_info["prompt_length"]).item()
    num_response_tokens = torch.sum(response_info["response_length"]).item()
    num_overall_tokens = num_prompt_tokens + num_response_tokens

    num_tokens_of_section = {
        "gen": num_response_tokens,
        **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]},
    }

    # Per-token numbers are only produced for sections that were actually timed.
    timed_sections = set(num_tokens_of_section) & set(timing_raw)

    return {
        **{f"timing_s/{name}": value for name, value in timing_raw.items()},
        **{
            f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
            for name in timed_sections
        },
    }


class RayPRIMETrainer(RayPPOTrainer):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    # TODO: support each role have individual ray_worker_group_cls,
    # i.e., support different backend of different role
    def __init__(
        self,
        config,
        tokenizer,
        role_worker_mapping: dict[Role, WorkerType],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup,
        reward_fn=None,
        val_reward_fn=None,
        device_name="cuda",
    ):
        """Forward all arguments to ``RayPPOTrainer`` and disable the critic."""
        # assert get_torch_device().is_available(), 'cuda must be available on driver'

        super().__init__(
            config,
            tokenizer,
            role_worker_mapping,
            resource_pool_manager,
            ray_worker_group_cls,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
            device_name=device_name,
        )

        # This trainer never uses a critic; metrics and timings skip value statistics.
        self.use_critic = False

    def _validate_config(self):
        """Run the parent config validation."""
        super()._validate_config()
        # TODO: Additional config checks can be added here

    def _create_dataloader(self, *args, **kwargs):
        """Build the train/val dataloaders and derive the total number of training steps.

        The positional and keyword arguments are accepted for signature
        compatibility and are not used here.
        """
        from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

        # TODO: we have to make sure the batch size is divisible by the dp size
        self.train_dataset = RLHFDataset(
            data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data
        )
        # use sampler for better ckpt resume
        if self.config.data.shuffle:
            train_dataloader_generator = torch.Generator()
            train_dataloader_generator.manual_seed(self.config.data.get("seed", 1))
            sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
        else:
            sampler = SequentialSampler(data_source=self.train_dataset)

        # Each step draws an oversampled batch; ``filter_and_downsample`` shrinks it later.
        self.train_dataloader = DataLoader(
            dataset=self.train_dataset,
            batch_size=int(self.config.data.train_batch_size * self.config.data.oversample_factor),
            drop_last=True,
            collate_fn=collate_fn,
            sampler=sampler,
        )

        self.val_dataset = RLHFDataset(
            data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data
        )
        # The whole validation set is served as one batch.
        self.val_dataloader = DataLoader(
            dataset=self.val_dataset,
            batch_size=len(self.val_dataset),
            shuffle=True,
            drop_last=True,
            collate_fn=collate_fn,
        )

        assert len(self.train_dataloader) >= 1, "train dataloader is empty"
        assert len(self.val_dataloader) >= 1, "val dataloader is empty"

        print(f"Size of train dataloader: {len(self.train_dataloader)}")
        print(f"Size of val dataloader: {len(self.val_dataloader)}")

        # inject total_training_steps to actor/critic optim_config. This is hacky.
        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        # An explicit config value overrides the value derived from the dataset size.
        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f"Total training steps: {self.total_training_steps}")

        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
            self.config.critic.optim.total_training_steps = total_training_steps

    def _save_checkpoint(self):
        """Save actor, optional reward model, dataloader state and the latest-step marker."""
        # path: given_path + `/global_step_{global_steps}` + `/actor`
        step_dirname = f"global_step_{self.global_steps}"
        local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, step_dirname)
        hdfs_dir = self.config.trainer.default_hdfs_dir
        print(f"local_global_step_folder: {local_global_step_folder}")

        actor_local_path = os.path.join(local_global_step_folder, "actor")
        actor_remote_path = None if hdfs_dir is None else os.path.join(hdfs_dir, step_dirname, "actor")
        self.actor_rollout_wg.save_checkpoint(
            actor_local_path,
            actor_remote_path,
            self.global_steps,
        )

        if self.use_rm:
            reward_local_path = os.path.join(local_global_step_folder, "reward")
            reward_remote_path = None if hdfs_dir is None else os.path.join(hdfs_dir, step_dirname, "reward")
            self.rm_wg.save_checkpoint(
                reward_local_path,
                reward_remote_path,
                self.global_steps,
            )

        # save dataloader
        # Do not rely on a worker having created the step folder on the driver's filesystem.
        os.makedirs(local_global_step_folder, exist_ok=True)
        dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
        import dill

        torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill)

        # latest checkpointed iteration tracker (for atomic usage)
        local_latest_checkpointed_iteration = os.path.join(
            self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt"
        )
        with open(local_latest_checkpointed_iteration, "w") as f:
            f.write(str(self.global_steps))

    def _load_checkpoint(self):
        """Restore trainer state from a checkpoint according to ``trainer.resume_mode``.

        Returns:
            ``0`` when resuming is disabled or no checkpoint exists in "auto"
            mode; ``None`` after a checkpoint has been loaded.

        Raises:
            NotImplementedError: If ``trainer.default_hdfs_dir`` is set and the
                checkpoint would have to be discovered (any mode other than
                "resume_path"), because loading from HDFS is not implemented.
        """
        resume_mode = self.config.trainer.resume_mode
        if resume_mode == "disable":
            return 0

        # load from hdfs
        if self.config.trainer.default_hdfs_dir is not None:
            # Discovering the latest checkpoint on HDFS is unsupported. An explicit
            # ``resume_from_path`` needs no discovery, so only that mode may proceed.
            if resume_mode != "resume_path":
                raise NotImplementedError("load from hdfs is not implemented yet")
            global_step_folder = None
        else:
            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path
            if not os.path.isabs(checkpoint_folder):
                working_dir = os.getcwd()
                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest

        # find global_step_folder
        if resume_mode == "auto":
            if global_step_folder is None:
                print("Training from scratch")
                return 0
        elif resume_mode == "resume_path":
            assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
            assert "global_step_" in self.config.trainer.resume_from_path, (
                "resume ckpt must specify the global_steps"
            )
            global_step_folder = self.config.trainer.resume_from_path
            if not os.path.isabs(global_step_folder):
                working_dir = os.getcwd()
                global_step_folder = os.path.join(working_dir, global_step_folder)
        print(f"Load from checkpoint folder: {global_step_folder}")

        # set global step
        self.global_steps = int(global_step_folder.split("global_step_")[-1])

        print(f"Setting global step to {self.global_steps}")
        print(f"Resuming from {global_step_folder}")

        actor_path = os.path.join(global_step_folder, "actor")
        reward_path = os.path.join(global_step_folder, "reward")

        # load actor
        self.actor_rollout_wg.load_checkpoint(
            actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
        )
        # load rm
        if self.use_rm:
            self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)

        # load dataloader,
        # TODO: from remote not implemented yet
        dataloader_local_path = os.path.join(global_step_folder, "data.pt")
        # The file holds a fully pickled DataLoader, which the ``weights_only=True``
        # default of recent PyTorch releases refuses to load.
        # SECURITY: unpickling executes arbitrary code; only resume from trusted checkpoints.
        self.train_dataloader = torch.load(dataloader_local_path, weights_only=False)
        if isinstance(self.train_dataloader.dataset, RLHFDataset):
            self.train_dataloader.dataset.resume_dataset_state()

    def fit(self):
        """
        The training loop of PPO.
        The driver process only need to call the compute functions of the worker group through RPC to
        construct the PPO dataflow. The light-weight advantage computation is done on the driver process.
        """
        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        # we start from step 1
        self.global_steps += 1

        n_samples = self.config.actor_rollout_ref.rollout.n

        for _epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}

                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # pop those keys for generation
                gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"])
                gen_batch = gen_batch.repeat(repeat_times=n_samples, interleave=True)

                with simple_timer("step", timing_raw):
                    # generate a batch
                    with simple_timer("gen", timing_raw):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                        # Move worker-side timings into ours; tolerate workers that report none.
                        timing_raw.update(gen_batch_output.meta_info.pop("timing", {}))

                    if self.config.algorithm.adv_estimator == "remax":
                        with simple_timer("gen_max", timing_raw):
                            # Greedy rollout used as the reward baseline.
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                            batch.batch["reward_baselines"] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    # One uid per prompt; repeating below makes all of its samples share it.
                    batch.non_tensor_batch["uid"] = np.array(
                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
                    )
                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=n_samples, interleave=True)
                    batch = batch.union(gen_batch_output)

                    # Balance the number of valid tokens across DP ranks.
                    # NOTE: This usually changes the order of data in the `batch`,
                    # which won't affect the advantage calculation (since it's based on uid),
                    # but might affect the loss calculation (due to the change of mini-batching).
                    # TODO: Decouple the DP balancing and mini-batching.
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    # verify
                    with simple_timer("verify", timing_raw):
                        scores = self.reward_fn.verify(batch)
                        metrics["acc"] = statistics.mean(scores)

                    # filter the batch. 1/oversample_factor samples will be kept.
                    # If there is a filter, prompts passing it will be prioritized.
                    batch = self.filter_and_downsample(scores, batch)
                    batch.meta_info["n"] = n_samples

                    # recompute old_log_probs
                    with simple_timer("old_log_prob", timing_raw):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        entropys = old_log_prob.batch["entropys"]
                        response_masks = compute_response_mask(batch)
                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                        old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
                        metrics.update(old_log_prob_metrics)
                        # Entropy is only needed for the metric above, not for training.
                        old_log_prob.batch.pop("entropys")
                        batch = batch.union(old_log_prob)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with simple_timer("ref", timing_raw):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    with simple_timer("adv", timing_raw):
                        if self.use_rm:
                            update_style = self.config.reward_model.model.get("update", "none")
                            if update_style == "none":  # only run forward
                                reward_output = self.rm_wg.compute_rm_score(batch)
                            elif update_style == "after":  # update and directly return the reward
                                reward_output = self.rm_wg.update_rm(batch)
                            elif update_style == "before":  # update reward model, and then run forward
                                reward_output = self.rm_wg.update_rm(batch)
                                if "metrics" in reward_output.meta_info.keys():
                                    reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"])
                                    metrics.update(reward_output_metrics)

                                reward_output = self.rm_wg.compute_rm_score(batch)
                            elif (
                                update_style == "reverse"
                            ):  # run forward to calculate statistics, then update reward model
                                reward_output = self.rm_wg.compute_rm_score(batch)
                                # broadcast q and acc tensor to each result
                                # NOTE(review): ``acc`` is presumably written into the batch by
                                # ``reward_fn.verify`` — confirm against the reward manager.
                                bc_td = DataProto.from_dict(
                                    tensors={
                                        "Q_bc": reward_output.batch["q"]
                                        .sum(dim=-1)
                                        .view(-1, n_samples)
                                        .unsqueeze(1)
                                        .expand(-1, n_samples, -1)
                                        .reshape(-1, n_samples),
                                        "acc_bc": batch.batch["acc"]
                                        .view(-1, n_samples)
                                        .unsqueeze(1)
                                        .expand(-1, n_samples, -1)
                                        .reshape(-1, n_samples),
                                    }
                                )
                                batch = batch.union(bc_td)
                                reward_output = self.rm_wg.update_rm(batch)
                            else:
                                raise NotImplementedError
                            batch = batch.union(reward_output)
                            if "metrics" in reward_output.meta_info.keys():
                                reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"])
                                metrics.update(reward_output_metrics)

                        # compute advantages, executed on the driver process
                        batch = compute_advantage(
                            batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config
                        )

                    # update actor
                    with simple_timer("update_actor", timing_raw):
                        actor_output = self.actor_rollout_wg.update_actor(batch)
                    actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                    metrics.update(actor_output_metrics)

                    # validate
                    if (
                        self.val_reward_fn is not None
                        and self.config.trainer.test_freq > 0
                        and self.global_steps % self.config.trainer.test_freq == 0
                    ):
                        with simple_timer("testing", timing_raw):
                            val_metrics: dict = self._validate()
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0:
                        with simple_timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                self.global_steps += 1

                # NOTE(review): the counter has already been incremented here, so the loop
                # stops once it *reaches* total_training_steps and the final checkpoint is
                # written under the incremented step number — confirm this is intended.
                if self.global_steps >= self.total_training_steps:
                    # perform validation after training
                    if self.val_reward_fn is not None:
                        val_metrics = self._validate()
                        pprint(f"Final validation metrics: {val_metrics}")
                        logger.log(data=val_metrics, step=self.global_steps)
                    # Save once more unless the step just finished was already checkpointed.
                    if (
                        self.config.trainer.save_freq > 0
                        and (self.global_steps - 1) % self.config.trainer.save_freq != 0
                    ):
                        with simple_timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()
                    return

    def filter_and_downsample(self, scores, batch: DataProto):
        """
        downsample the batch according to oversample_factor
        samples passing the filters will be prioritized

        Args:
            scores: One score per sample, grouped so that each run of
                ``rollout.n`` consecutive entries belongs to one prompt.
            batch: The oversampled batch; it is reordered and truncated in place.

        Returns:
            The same ``batch`` object, keeping ``1 / oversample_factor`` of the
            prompts (rounded down) together with all of their samples.
        """
        n_samples = int(self.config.actor_rollout_ref.rollout.n)
        # Force a float dtype: bool/int scores would make ``torch.mean`` below raise.
        reward_matrix = torch.tensor(scores, dtype=torch.float32).reshape(-1, n_samples)
        num_prompts = reward_matrix.shape[0]

        filter_mask = torch.ones(num_prompts, dtype=torch.bool)

        if self.config.data.filter_accuracy:
            # Drop prompts whose mean score lies outside the configured accuracy band.
            acc_tensor = torch.mean(reward_matrix, dim=-1)
            filter_mask[
                (acc_tensor > self.config.data.accuracy_upper_bound)
                | (acc_tensor < self.config.data.accuracy_lower_bound)
            ] = False

        if self.config.data.filter_truncate:
            # Drop prompts whose longest response (nearly) fills the response budget.
            length_matrix = (
                batch.batch["attention_mask"][:, -batch.batch["responses"].shape[-1] :]
                .sum(dim=-1)
                .reshape(-1, n_samples)
            )
            length_tensor = torch.max(length_matrix, dim=-1)[0]
            filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False

        # Prompts that pass the filters come first; expand prompt indices to sample indices.
        reorder_index = torch.argsort(filter_mask, descending=True)
        reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1)

        # Truncate on whole prompts so no group of ``n_samples`` responses is split;
        # downstream code reshapes the batch to (-1, n_samples).
        kept_prompts = int(num_prompts // self.config.data.oversample_factor)
        batch.reorder(reorder_index[: kept_prompts * n_samples])  # this operation is inplace

        return batch