PPO Ray Trainer
===============

We implement the ``RayPPOTrainer``, a trainer that runs on the driver
process on a single CPU/GPU node (CPU by default).

The ``RayPPOTrainer`` includes three core parts: data preparation,
WorkerGroup initialization, and the PPO training loop.

Data Preparation
----------------

The ``RayPPOTrainer``, as a single process, is responsible for loading a
complete batch of samples (prompts) from the dataset and then dispatching
them to the different worker groups running on different GPUs.

To generalize data loading, we implement the ``RLHFDataset`` class,
which loads the preprocessed parquet files, applies chat templates to the
prompts, adds padding, truncates prompts that exceed the maximum prompt
length, and then tokenizes them.

.. code:: python

   self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
                                       tokenizer=self.tokenizer,
                                       config=self.config.data)

The dataloader then iterates over the dataset using the PPO mini-batch size.
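
Below is a minimal sketch of how such a dataloader could be built. The
``collate_fn`` helper and the ``train_batch_size`` config key are assumptions
made for illustration and may differ from the actual implementation.

.. code:: python

   from torch.utils.data import DataLoader

   # Illustrative sketch only (not the exact verl implementation).
   # `collate_fn` is assumed to stack the tokenized prompt fields
   # (input_ids, attention_mask, position_ids) into batched tensors.
   self.train_dataloader = DataLoader(dataset=self.train_dataset,
                                      batch_size=self.config.data.train_batch_size,  # assumed config key
                                      shuffle=True,
                                      drop_last=True,
                                      collate_fn=collate_fn)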

WorkerGroup Initialization
--------------------------

We first introduce a basic implementation of initializing the
``WorkerGroup`` of the actor model on a given set of GPUs.

.. code:: python

   # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
   # For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
   # For the Megatron backend, we recommend max_colocate_count>1, which can use a different WorkerGroup for each model.
   resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes,
                                   use_gpu=True,
                                   max_colocate_count=1)
   # define actor rollout cls to be init on remote
   actor_rollout_cls = RayClassWithInitArgs(cls=ActorRolloutWorker)
   # define actor_rollout worker group
   actor_rollout_worker_group = MegatronRayWorkerGroup(resource_pool=resource_pool,
                                                       ray_cls_with_init=actor_rollout_cls,
                                                       default_megatron_kwargs=config.actor_rollout.megatron)

In the above implementation, different WorkerGroups, such as
``actor_rollout_worker_group``, ``critic_worker_group`` and
``ref_worker_group``, each run in their own set of processes.

The driver process can then call the distributed compute functions of the
``actor_rollout_worker_group`` and the other roles to construct the RL
training loop.
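
For example, the driver can invoke these distributed compute functions as
plain method calls on the worker groups; the method names below mirror
those used in the PPO training loop later in this document.

.. code:: python

   # Sketch of driver-side calls; each call is dispatched to the remote
   # workers of the corresponding WorkerGroup and the results are collected
   # back on the driver.
   gen_batch_output = actor_rollout_worker_group.generate_sequences(gen_batch)
   values = critic_worker_group.compute_values(batch)
   ref_log_prob = ref_worker_group.compute_ref_log_prob(batch)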

For models colocated on the same set of GPUs, we further provide a
fine-grained optimization that merges the ``worker_group`` of different roles
into the same process. This optimization avoids redundant
CUDA/distributed contexts across processes.

.. code:: python

   # initialize WorkerGroup
   # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
   # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
   # See TODO(url) for more information.
   all_wg = {}
   for resource_pool, class_dict in self.resource_pool_to_cls.items():
       worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
       wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
       spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
       all_wg.update(spawn_wg)

   if self.use_critic:
       self.critic_wg = all_wg['critic']
       self.critic_wg.init_model()

   if self.use_reference_policy:
       self.ref_policy_wg = all_wg['ref']
       self.ref_policy_wg.init_model()

   if self.use_rm:
       self.rm_wg = all_wg['rm']
       self.rm_wg.init_model()

   # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
   self.actor_rollout_wg = all_wg['actor_rollout']
   self.actor_rollout_wg.init_model()

.. note:: For the Megatron backend, if we merge the ``worker_groups`` into the same processes, all roles will use the same 3D parallel size. To optimize this, we may need to maintain several 3D process groups for each role within the same distributed context. If you want to use different 3D parallel sizes for different roles, please follow an architecture similar to the first code block and initialize each role's ``worker_group`` separately.
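
As a rough sketch of that per-role setup, one resource pool with
``max_colocate_count > 1`` can host a separate WorkerGroup per role, following
the first code block. The ``CriticWorker`` class and the ``config.critic.megatron``
key below are assumed names for illustration, not values taken from the actual
configuration.

.. code:: python

   # Illustrative sketch: one resource pool with max_colocate_count > 1 and a
   # separate WorkerGroup per role, so each role can pick its own 3D parallel size.
   resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes,
                                   use_gpu=True,
                                   max_colocate_count=3)
   actor_rollout_wg = MegatronRayWorkerGroup(resource_pool=resource_pool,
                                             ray_cls_with_init=RayClassWithInitArgs(cls=ActorRolloutWorker),
                                             default_megatron_kwargs=config.actor_rollout.megatron)
   critic_wg = MegatronRayWorkerGroup(resource_pool=resource_pool,
                                      ray_cls_with_init=RayClassWithInitArgs(cls=CriticWorker),  # assumed worker class
                                      default_megatron_kwargs=config.critic.megatron)            # assumed config key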


PPO Training Loop
-----------------

We implement the PPO training loop by calling the functions in the
``worker_group`` of each role. The input and output of each function is
a ``DataProto`` object implemented in `protocol.py <https://github.com/volcengine/verl/blob/main/verl/protocol.py>`_. In the training
loop, the trainer dispatches data to and collects data from the different GPUs
following the transfer protocols wrapped in the workers' functions. The
PPO micro-batch computation happens inside the ``update_actor`` and
``update_critic`` functions.
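
As a concrete illustration of this data flow, the driver manipulates
``DataProto`` objects with a handful of operations; the snippet below uses
only the calls that appear in the loop further down.

.. code:: python

   # Wrap a dict of tensors coming from the dataloader into a DataProto.
   batch = DataProto.from_single_dict(batch_dict)
   # Pop the fields needed for generation into a smaller DataProto.
   gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
   # The call is dispatched to the rollout workers following the transfer protocol.
   gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
   # Merge the generated responses back into the original batch.
   batch = batch.union(gen_batch_output)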

To extend to other RLHF algorithms, such as DPO and GRPO, please refer to
:doc:`../advance/dpo_extension`.

.. code:: python

   def fit(self):
       """
       The training loop of PPO.
       The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
       The light-weight advantage computation is done on the driver process.
       """
       from verl.utils.tracking import Tracking
       from omegaconf import OmegaConf

       logger = Tracking(project_name=self.config.trainer.project_name,
                           experiment_name=self.config.trainer.experiment_name,
                           default_backend=self.config.trainer.logger,
                           config=OmegaConf.to_container(self.config, resolve=True))

       global_steps = 0

       # perform validation before training
       # currently, we only support validation using the reward_function.
       if self.val_reward_fn is not None:
           val_metrics = self._validate()
           pprint(f'Initial validation metrics: {val_metrics}')

       for epoch in range(self.config.trainer.total_epochs):
           for batch_dict in self.train_dataloader:
               metrics = {}

               batch: DataProto = DataProto.from_single_dict(batch_dict)
               # batch = batch.to('cuda')

               # pop those keys for generation
               gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

               # generate a batch
               with Timer(name='gen', logger=None) as timer:
                   gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
               metrics['timing/gen'] = timer.last

               batch = batch.union(gen_batch_output)

               if self.use_reference_policy:
                   # compute reference log_prob
                   with Timer(name='ref', logger=None) as timer:
                       ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                       batch = batch.union(ref_log_prob)
                   metrics['timing/ref'] = timer.last

               # compute values
               with Timer(name='values', logger=None) as timer:
                   values = self.critic_wg.compute_values(batch)
                   batch = batch.union(values)
               metrics['timing/values'] = timer.last

               with Timer(name='adv', logger=None) as timer:
                   # compute scores. Support both model and function-based.
                   # We first compute the scores using reward model. Then, we call reward_fn to combine
                   # the results from reward model and rule-based results.
                   if self.use_rm:
                       # we first compute reward model score
                       reward_tensor = self.rm_wg.compute_rm_score(batch)
                       batch = batch.union(reward_tensor)

                   # we combine with rule-based rm
                   reward_tensor = self.reward_fn(batch)
                   batch.batch['token_level_scores'] = reward_tensor

                   # compute rewards. apply_kl_penalty if available
                   batch, kl_metrics = apply_kl_penalty(batch,
                                                           kl_ctrl=self.kl_ctrl_in_reward,
                                                           kl_penalty=self.config.algorithm.kl_penalty)
                   metrics.update(kl_metrics)

                   # compute advantages, executed on the driver process
                   batch = compute_advantage(batch,
                                               self.config.algorithm.gamma,
                                               self.config.algorithm.lam,
                                               adv_estimator=self.config.algorithm.adv_estimator)
               metrics['timing/adv'] = timer.last

               # update critic
               if self.use_critic:
                   with Timer(name='update_critic', logger=None) as timer:
                       critic_output = self.critic_wg.update_critic(batch)
                   metrics['timing/update_critic'] = timer.last
                   critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                   metrics.update(critic_output_metrics)

               # implement critic warmup
               if self.config.trainer.critic_warmup <= global_steps:
                   # update actor
                   with Timer(name='update_actor', logger=None) as timer:
                       actor_output = self.actor_rollout_wg.update_actor(batch)
                   metrics['timing/update_actor'] = timer.last
                   actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                   metrics.update(actor_output_metrics)

               # validate
               if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0:
                   with Timer(name='testing', logger=None) as timer:
                       val_metrics: dict = self._validate()
                       val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
                   metrics['timing/testing'] = timer.last
                   metrics.update(val_metrics)

               # collect metrics
               data_metrics = compute_data_metrics(batch=batch)
               metrics.update(data_metrics)

               # TODO: make a canonical logger that supports various backends
               logger.log(data=metrics, step=global_steps)

               if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0:
                   actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
                                                   f'global_step_{global_steps}')
                   actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor')
                   self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

                   if self.use_critic:
                       critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
                                                           f'global_step_{global_steps}')
                       critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic')
                       self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

               global_steps += 1

       # perform validation after training
       if self.val_reward_fn is not None:
           val_metrics = self._validate()
           pprint(f'Final validation metrics: {val_metrics}')