# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace.
"""

import os
import statistics
import uuid
from copy import deepcopy
from pprint import pprint

import numpy as np
import torch
from omegaconf import OmegaConf, open_dict

from verl import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager, reduce_metrics, _timer
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

from . import prime_core_algos


def compute_advantage(data: DataProto, adv_estimator, config):
    if adv_estimator == 'rloo':
        responses = data.batch['responses']
        response_length = responses.size(-1)
        attention_mask = data.batch['attention_mask']
        response_mask = attention_mask[:, -response_length:]
        advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, response_mask,
                                                                             config.actor_rollout_ref.rollout.n,
                                                                             config)
        data.batch['advantages'] = advantages
        data.batch['returns'] = returns
    else:
        raise NotImplementedError
    return data

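
# RLOO ("REINFORCE leave-one-out") baselines each rollout against the other rollouts of the same
# prompt. Illustrative sketch only; the exact computation (and any extra reward terms controlled by
# `config`) lives in prime_core_algos.compute_rloo_advantage_return:
#   baseline_i  = (sum_j r_j - r_i) / (n - 1)
#   advantage_i = r_i - baseline_i
# e.g. rewards [1.0, 0.0, 0.0] with n = 3 give advantages [1.0, -0.5, -0.5].
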

def compute_data_metrics(batch, use_critic=True):
    advantages = batch.batch['advantages']
    returns = batch.batch['returns']
    max_response_length = batch.batch['responses'].shape[-1]

    prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool()
    response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool()

    max_prompt_length = prompt_mask.size(-1)

    response_info = _compute_response_info(batch)
    prompt_length = response_info['prompt_length']
    response_length = response_info['response_length']

    valid_adv = torch.masked_select(advantages, response_mask)
    valid_returns = torch.masked_select(returns, response_mask)

    if use_critic:
        values = batch.batch['values']
        valid_values = torch.masked_select(values, response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)

    metrics = {
        # adv
        'critic/advantages/mean': torch.mean(valid_adv).detach().item(),
        'critic/advantages/max': torch.max(valid_adv).detach().item(),
        'critic/advantages/min': torch.min(valid_adv).detach().item(),
        # returns
        'critic/returns/mean': torch.mean(valid_returns).detach().item(),
        'critic/returns/max': torch.max(valid_returns).detach().item(),
        'critic/returns/min': torch.min(valid_returns).detach().item(),
        **({
            # values
            'critic/values/mean': torch.mean(valid_values).detach().item(),
            'critic/values/max': torch.max(valid_values).detach().item(),
            'critic/values/min': torch.min(valid_values).detach().item(),
            # vf explained var
            'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
        } if use_critic else {}),
        # response length
        'response_length/mean': torch.mean(response_length).detach().item(),
        'response_length/max': torch.max(response_length).detach().item(),
        'response_length/min': torch.min(response_length).detach().item(),
        'response_length/clip_ratio': torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
        # prompt length
        'prompt_length/mean': torch.mean(prompt_length).detach().item(),
        'prompt_length/max': torch.max(prompt_length).detach().item(),
        'prompt_length/min': torch.min(prompt_length).detach().item(),
        'prompt_length/clip_ratio': torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
    }
    return metrics


def compute_timing_metrics(batch, timing_raw):
    response_info = _compute_response_info(batch)
    num_prompt_tokens = torch.sum(response_info['prompt_length']).item()
    num_response_tokens = torch.sum(response_info['response_length']).item()
    num_overall_tokens = num_prompt_tokens + num_response_tokens

    num_tokens_of_section = {
        'gen': num_response_tokens,
        **{
            name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor']
        },
    }

    return {
        **{
            f'timing_s/{name}': value for name, value in timing_raw.items()
        },
        **{
            f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name]
            for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
        },
    }


class RayPRIMETrainer(RayPPOTrainer):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    # TODO: support each role having an individual ray_worker_group_cls,
    # i.e., support a different backend for each role

    def __init__(self,
                 config,
                 tokenizer,
                 role_worker_mapping: dict[Role, WorkerType],
                 resource_pool_manager: ResourcePoolManager,
                 ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
                 reward_fn=None,
                 val_reward_fn=None):

        # assert torch.cuda.is_available(), 'cuda must be available on driver'

        super().__init__(config, tokenizer, role_worker_mapping, resource_pool_manager, ray_worker_group_cls,
                         reward_fn, val_reward_fn)

        self.use_critic = False

    def _validate_config(self):
        super()._validate_config()
        # TODO: Additional config checks can be added here
        config = self.config

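    # NOTE: `_create_dataloader` below deliberately oversamples each train batch by
    # `data.oversample_factor`; after rollout and verification, `filter_and_downsample` keeps
    # len(batch) // oversample_factor samples, prioritizing prompts that pass the accuracy /
    # truncation filters, so each update sees roughly `data.train_batch_size` prompts
    # (times `rollout.n` responses each).
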
    def _create_dataloader(self):
        from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

        # TODO: we have to make sure the batch size is divisible by the dp size
        self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
                                         tokenizer=self.tokenizer,
                                         config=self.config.data)

        # use sampler for better ckpt resume
        if self.config.data.shuffle:
            train_dataloader_generator = torch.Generator()
            train_dataloader_generator.manual_seed(self.config.data.get('seed', 1))
            sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
        else:
            sampler = SequentialSampler(data_source=self.train_dataset)

        self.train_dataloader = DataLoader(dataset=self.train_dataset,
                                           batch_size=int(self.config.data.train_batch_size *
                                                          self.config.data.oversample_factor),
                                           drop_last=True,
                                           collate_fn=collate_fn,
                                           sampler=sampler)

        self.val_dataset = RLHFDataset(data_files=self.config.data.val_files,
                                       tokenizer=self.tokenizer,
                                       config=self.config.data)
        self.val_dataloader = DataLoader(dataset=self.val_dataset,
                                         batch_size=len(self.val_dataset),
                                         shuffle=True,
                                         drop_last=True,
                                         collate_fn=collate_fn)

        assert len(self.train_dataloader) >= 1
        assert len(self.val_dataloader) >= 1

        print(f'Size of train dataloader: {len(self.train_dataloader)}')
        print(f'Size of val dataloader: {len(self.val_dataloader)}')

        # inject total_training_steps to actor/critic optim_config. This is hacky.
        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f'Total training steps: {self.total_training_steps}')

        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
            self.config.critic.optim.total_training_steps = total_training_steps

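    # Checkpoint layout (written by `_save_checkpoint`, read back by `_load_checkpoint`):
    #   {default_local_dir}/global_step_{N}/actor    - actor worker checkpoint
    #   {default_local_dir}/global_step_{N}/reward   - reward-model checkpoint (only when `use_rm`)
    #   {default_local_dir}/global_step_{N}/data.pt  - train dataloader state (pickled with dill)
    #   {default_local_dir}/latest_checkpointed_iteration.txt - last saved global step
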
    def _save_checkpoint(self):
        # path: given_path + `/global_step_{global_steps}` + `/actor`
        local_global_step_folder = os.path.join(self.config.trainer.default_local_dir,
                                                f'global_step_{self.global_steps}')
        print(f'local_global_step_folder: {local_global_step_folder}')
        actor_local_path = os.path.join(local_global_step_folder, 'actor')

        actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
            self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
        self.actor_rollout_wg.save_checkpoint(actor_local_path,
                                              actor_remote_path,
                                              self.global_steps,
                                              remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

        if self.use_rm:
            reward_local_path = os.path.join(local_global_step_folder, 'reward')
            reward_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
                self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'reward')
            self.rm_wg.save_checkpoint(reward_local_path,
                                       reward_remote_path,
                                       self.global_steps,
                                       remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)

        # save dataloader
        dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
        import dill
        torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill)

        # latest checkpointed iteration tracker (for atomic usage)
        local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir,
                                                           'latest_checkpointed_iteration.txt')
        with open(local_latest_checkpointed_iteration, 'w') as f:
            f.write(str(self.global_steps))

    def _load_checkpoint(self):
        if self.config.trainer.resume_mode == 'disable':
            return 0

        # load from hdfs
        if self.config.trainer.default_hdfs_dir is not None:
            raise NotImplementedError('load from hdfs is not implemented yet')
        else:
            checkpoint_folder = self.config.trainer.default_local_dir  # TODO: check path
            if not os.path.isabs(checkpoint_folder):
                working_dir = os.getcwd()
                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
            global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest

        # find global_step_folder
        if self.config.trainer.resume_mode == 'auto':
            if global_step_folder is None:
                print('Training from scratch')
                return 0
        else:
            if self.config.trainer.resume_mode == "resume_path":
                assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
                assert 'global_step_' in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps"
                global_step_folder = self.config.trainer.resume_from_path
                if not os.path.isabs(global_step_folder):
                    working_dir = os.getcwd()
                    global_step_folder = os.path.join(working_dir, global_step_folder)

        print(f'Load from checkpoint folder: {global_step_folder}')
        # set global step
        self.global_steps = int(global_step_folder.split('global_step_')[-1])

        print(f'Setting global step to {self.global_steps}')
        print(f'Resuming from {global_step_folder}')

        actor_path = os.path.join(global_step_folder, 'actor')
        reward_path = os.path.join(global_step_folder, 'reward')

        # load actor
        self.actor_rollout_wg.load_checkpoint(actor_path,
                                              del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
        # load rm
        if self.use_rm:
            self.rm_wg.load_checkpoint(reward_path,
                                       del_local_after_load=self.config.trainer.del_local_ckpt_after_load)

        # load dataloader,
        # TODO: from remote not implemented yet
        dataloader_local_path = os.path.join(global_step_folder, 'data.pt')
        self.train_dataloader = torch.load(dataloader_local_path)
        if isinstance(self.train_dataloader.dataset, RLHFDataset):
            self.train_dataloader.dataset.resume_dataset_state()

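    # One PRIME training step in `fit` (driver-side orchestration):
    #   1. generate rollouts             -> actor_rollout_wg.generate_sequences
    #   2. verify answers                -> self.reward_fn.verify
    #   3. filter + downsample           -> self.filter_and_downsample
    #   4. recompute old log-probs       -> actor_rollout_wg.compute_log_prob
    #   5. reference log-probs (opt.)    -> ref_policy_wg.compute_ref_log_prob
    #   6. reward-model score / update   -> rm_wg.compute_rm_score / rm_wg.update_rm (when use_rm)
    #   7. advantages (RLOO, on driver)  -> compute_advantage
    #   8. actor update                  -> actor_rollout_wg.update_actor
    # Validation and checkpointing are interleaved according to test_freq / save_freq.
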
    def fit(self):
        """
        The training loop of PPO.
        The driver process only needs to call the compute functions of the worker group through RPC
        to construct the PPO dataflow. The lightweight advantage computation is done on the driver process.
        """
        from verl.utils.tracking import Tracking
        from omegaconf import OmegaConf

        logger = Tracking(project_name=self.config.trainer.project_name,
                          experiment_name=self.config.trainer.experiment_name,
                          default_backend=self.config.trainer.logger,
                          config=OmegaConf.to_container(self.config, resolve=True))

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
            val_metrics = self._validate()
            pprint(f'Initial validation metrics: {val_metrics}')
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get('val_only', False):
                return

        # we start from step 1
        self.global_steps += 1

        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}

                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # pop those keys for generation
                gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

                with _timer('step', timing_raw):
                    # generate a batch
                    with _timer('gen', timing_raw):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                    if self.config.algorithm.adv_estimator == 'remax':
                        with _timer('gen_max', timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info['do_sample'] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                            batch.batch['reward_baselines'] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                                                             dtype=object)
                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)

                    # balance the number of valid tokens on each dp rank.
                    # Note that this breaks the order of data inside the batch.
                    # Please take care when you implement group based adv computation such as GRPO and rloo
                    # self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()

                    # verify
                    with _timer('verify', timing_raw):
                        scores = self.reward_fn.verify(batch)
                        metrics['acc'] = statistics.mean(scores)

                    # filter the batch. 1/oversample_factor samples will be kept.
                    # If there is a filter, prompts passing it will be prioritized.
                    batch = self.filter_and_downsample(scores, batch)

                    batch.meta_info['n'] = self.config.actor_rollout_ref.rollout.n
                    n_samples = self.config.actor_rollout_ref.rollout.n

                    # recompute old_log_probs
                    with _timer('old_log_prob', timing_raw):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        batch = batch.union(old_log_prob)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with _timer('ref', timing_raw):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    with _timer('adv', timing_raw):
                        if self.use_rm:
                            update_style = self.config.reward_model.model.get('update', 'none')
                            if update_style == 'none':  # only run forward
                                reward_output = self.rm_wg.compute_rm_score(batch)
                            elif update_style == 'after':  # update and directly return the reward
                                reward_output = self.rm_wg.update_rm(batch)
                            elif update_style == 'before':  # update reward model, and then run forward
                                reward_output = self.rm_wg.update_rm(batch)
                                if 'metrics' in reward_output.meta_info.keys():
                                    reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics'])
                                    metrics.update(reward_output_metrics)
                                reward_output = self.rm_wg.compute_rm_score(batch)
                            elif update_style == 'reverse':
                                # run forward to calculate statistics, then update reward model
                                reward_output = self.rm_wg.compute_rm_score(batch)
                                # broadcast q and acc tensor to each result
                                bc_td = DataProto.from_dict(
                                    tensors={
                                        'Q_bc':
                                            reward_output.batch['q'].sum(dim=-1).view(-1, n_samples).unsqueeze(
                                                1).expand(-1, n_samples, -1).reshape(-1, n_samples),
                                        'acc_bc':
                                            batch.batch['acc'].view(-1, n_samples).unsqueeze(1).expand(
                                                -1, n_samples, -1).reshape(-1, n_samples)
                                    })
                                batch = batch.union(bc_td)
                                reward_output = self.rm_wg.update_rm(batch)
                            else:
                                raise NotImplementedError
                            batch = batch.union(reward_output)
                            if 'metrics' in reward_output.meta_info.keys():
                                reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics'])
                                metrics.update(reward_output_metrics)

                        # compute advantages, executed on the driver process
                        batch = compute_advantage(batch,
                                                  adv_estimator=self.config.algorithm.adv_estimator,
                                                  config=self.config)

                    # update actor
                    with _timer('update_actor', timing_raw):
                        actor_output = self.actor_rollout_wg.update_actor(batch)
                    actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                    metrics.update(actor_output_metrics)

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
                            self.global_steps % self.config.trainer.test_freq == 0:
                        with _timer('testing', timing_raw):
                            val_metrics: dict = self._validate()
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and \
                            self.global_steps % self.config.trainer.save_freq == 0:
                        with _timer('save_checkpoint', timing_raw):
                            self._save_checkpoint()

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                self.global_steps += 1

                if self.global_steps >= self.total_training_steps:
                    # perform validation after training
                    if self.val_reward_fn is not None:
                        val_metrics = self._validate()
                        pprint(f'Final validation metrics: {val_metrics}')
                        logger.log(data=val_metrics, step=self.global_steps)
                    if self.config.trainer.save_freq > 0 and \
                            (self.global_steps - 1) % self.config.trainer.save_freq != 0:
                        with _timer('save_checkpoint', timing_raw):
                            self._save_checkpoint()
                    return

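    # Illustrative walk-through of `filter_and_downsample` (values assumed for the example):
    # with rollout.n = 2, oversample_factor = 2 and 4 prompts, `scores` has 8 entries and
    # reward_matrix has shape (4, 2). Suppose filter_mask = [True, False, True, True]
    # (prompt 1 rejected by the accuracy/truncation filters). argsort(descending=True) moves the
    # passing prompts to the front, e.g. a prompt order of [0, 2, 3, 1]; expanding by n gives
    # per-sample indices [0, 1, 4, 5, 6, 7, 2, 3], and only the first
    # len(batch) // oversample_factor = 4 samples are kept, i.e. the rollouts of prompts 0 and 2.
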
    def filter_and_downsample(self, scores, batch: DataProto):
        """
        Downsample the batch according to oversample_factor.
        Samples passing the filters will be prioritized.
        """
        n_samples = int(self.config.actor_rollout_ref.rollout.n)
        reward_matrix = torch.tensor(scores).reshape(-1, n_samples)

        filter_mask = torch.ones((reward_matrix.shape[0]), dtype=torch.bool)

        if self.config.data.filter_accuracy:
            acc_tensor = torch.mean(reward_matrix, dim=-1)
            filter_mask[(acc_tensor > self.config.data.accuracy_upper_bound) |
                        (acc_tensor < self.config.data.accuracy_lower_bound)] = False

        if self.config.data.filter_truncate:
            length_matrix = batch.batch['attention_mask'][:, -batch.batch['responses'].shape[-1]:].sum(dim=-1).reshape(
                -1, n_samples)
            length_tensor = torch.max(length_matrix, dim=-1)[0]
            filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False

        reorder_index = torch.argsort(filter_mask, descending=True)
        reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1)
        # this operation is inplace
        batch.reorder(reorder_index[:int(len(batch) // self.config.data.oversample_factor)])

        return batch